[WIP] Non-functional yet
Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
Committed by: Michel Aractingi
Parent: 114ec644d0
Commit: 056f79d358
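The commit message names `VideoRecordConfig`, `ManiskillEnvConfig`, and an action-scaling wrapper, but their definitions are not part of the hunks shown here. The following is a rough sketch of how such pieces could fit together. Every field name, default value, and the `ScaleAction` and `EchoEnv` classes are assumptions made for illustration, not the code from this commit.

```python
from dataclasses import dataclass, field


@dataclass
class VideoRecordConfig:
    # Hypothetical fields: the real class is not shown in this diff.
    enabled: bool = False
    record_dir: str = "videos"


@dataclass
class ManiskillEnvConfig:
    # Hypothetical fields and placeholder values.
    task: str = "PushCube-v1"
    action_scale: float = 1.0
    video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)


class EchoEnv:
    """Stand-in environment whose step() returns the action it received."""

    def step(self, action):
        return action


class ScaleAction:
    """Illustrative wrapper that multiplies every action component by a constant."""

    def __init__(self, env, scale: float):
        self.env = env
        self.scale = scale

    def step(self, action):
        # Scale first, then forward to the wrapped environment.
        return self.env.step([a * self.scale for a in action])


cfg = ManiskillEnvConfig(action_scale=0.5)
env = ScaleAction(EchoEnv(), cfg.action_scale)
```

Keeping the scale factor in the config object rather than hard-coding it in the wrapper means the actor and learner can be built from the same settings.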
@@ -31,12 +31,19 @@ class SACConfig(PreTrainedConfig):
    Args:
        n_obs_steps: Number of environment steps worth of observations to pass to the policy.
        normalization_mapping: Mapping from feature types to normalization modes.
        dataset_stats: Statistics for normalizing different data types.
        camera_number: Number of cameras to use.
        device: Device to use for training.
        storage_device: Device to use for storage.
        vision_encoder_name: Name of the vision encoder to use.
        freeze_vision_encoder: Whether to freeze the vision encoder.
        image_encoder_hidden_dim: Hidden dimension for the image encoder.
        shared_encoder: Whether to use a shared encoder.
        online_steps: Total number of online training steps.
        online_env_seed: Seed for the online environment.
        online_buffer_capacity: Capacity of the online replay buffer.
        online_step_before_learning: Number of steps to collect before starting learning.
        policy_update_freq: Frequency of policy updates.
        discount: Discount factor for the RL algorithm.
        temperature_init: Initial temperature for entropy regularization.
        num_critics: Number of critic networks.
@@ -54,6 +61,8 @@ class SACConfig(PreTrainedConfig):
        critic_network_kwargs: Additional arguments for critic networks.
        actor_network_kwargs: Additional arguments for actor network.
        policy_kwargs: Additional arguments for policy.
        actor_learner_config: Configuration for actor-learner communication.
        concurrency: Configuration for concurrency model.
    """

    # Input / output structure
@@ -86,13 +95,21 @@ class SACConfig(PreTrainedConfig):

    # Architecture specifics
    camera_number: int = 1
    device: str = "cuda"
    storage_device: str = "cpu"
    # Set to "helper2424/resnet10" for hil serl
    vision_encoder_name: str | None = None
    freeze_vision_encoder: bool = True
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = True

    # Training parameter
    online_steps: int = 1000000
    online_env_seed: int = 10000
    online_buffer_capacity: int = 10000
    online_step_before_learning: int = 100
    policy_update_freq: int = 1

    # SAC algorithm parameters
    discount: float = 0.99
    temperature_init: float = 1.0
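The diff does not show how `online_step_before_learning` and `policy_update_freq` are consumed by the learner. One plausible reading, sketched below, is that no update happens until the buffer holds enough transitions, after which the critic is updated on every optimization step and the actor only on every `policy_update_freq`-th step. The `update_schedule` helper is hypothetical, and this interpretation is an assumption rather than the behaviour confirmed by this commit.

```python
def update_schedule(
    num_transitions: int,
    optimization_step: int,
    *,
    online_step_before_learning: int = 100,
    policy_update_freq: int = 1,
) -> tuple[bool, bool]:
    """Return (update_critic, update_actor) for the current step.

    Assumed semantics, not taken from the commit:
    - nothing is updated until the buffer holds enough transitions;
    - afterwards the critic updates every step and the actor updates
      once every `policy_update_freq` optimization steps.
    """
    if num_transitions < online_step_before_learning:
        return (False, False)
    update_actor = optimization_step % policy_update_freq == 0
    return (True, update_actor)
```

With the defaults from the hunk above (`online_step_before_learning=100`, `policy_update_freq=1`), this would start learning after 100 collected transitions and update the actor on every step.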
@@ -132,11 +149,17 @@ class SACConfig(PreTrainedConfig):
        }
    )

    # Deprecated, kept for backward compatibility
    actor_learner_config: dict[str, str | int] = field(
        default_factory=lambda: {
            "learner_host": "127.0.0.1",
            "learner_port": 50051,
            "policy_parameters_push_frequency": 4,
        }
    )
    concurrency: dict[str, str] = field(
        default_factory=lambda: {
            "actor": "threads",
            "learner": "threads"
        }
    )
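Both dictionary fields above use `field(default_factory=...)`, which is how Python dataclasses give each instance its own mutable default. A small standalone sketch of that behaviour follows. `ConcurrencySketch` is a hypothetical stand-in for `SACConfig`, and the `"processes"` value is only an illustrative override, not a value confirmed by this diff.

```python
from dataclasses import dataclass, field


@dataclass
class ConcurrencySketch:
    # Same shape as the `concurrency` field in the hunk above.
    concurrency: dict[str, str] = field(
        default_factory=lambda: {
            "actor": "threads",
            "learner": "threads",
        }
    )


a = ConcurrencySketch()
b = ConcurrencySketch()

# Mutating one instance's dict does not leak into the other,
# because the factory runs once per instance.
a.concurrency["actor"] = "processes"  # illustrative value only
```

Writing a plain dict literal as the default instead would raise a `ValueError` at class-definition time, since dataclasses reject mutable defaults such as `dict` and `list`.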