[WIP] Not functional yet

Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
This commit is contained in:
AdilZouitine
2025-03-26 08:15:05 +00:00
committed by Michel Aractingi
parent 114ec644d0
commit 056f79d358
9 changed files with 667 additions and 436 deletions

View File

@@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
dataset_stats: Statistics for normalizing different data types.
camera_number: Number of cameras to use.
device: Device to use for training.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
online_steps: Total number of online training steps.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
online_step_before_learning: Number of steps to collect before starting learning.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
@@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
actor_learner_config: Configuration for actor-learner communication.
concurrency: Concurrency model used by the actor and the learner (each defaults to "threads").
"""
# Input / output structure
@@ -86,13 +95,21 @@ class SACConfig(PreTrainedConfig):
# Architecture specifics
camera_number: int = 1
device: str = "cuda"
storage_device: str = "cpu"
# Set to "helper2424/resnet10" for HIL-SERL
vision_encoder_name: str | None = None
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
# Training parameters
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 10000
online_step_before_learning: int = 100
policy_update_freq: int = 1
# SAC algorithm parameters
discount: float = 0.99
temperature_init: float = 1.0
@@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
}
)
# Deprecated, kept for backward compatibility
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
"policy_parameters_push_frequency": 4,
}
)
concurrency: dict[str, str] = field(
default_factory=lambda: {
"actor": "threads",
"learner": "threads"
}
)