[WIP] Update SAC configuration and environment settings

- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
This commit is contained in:
AdilZouitine
2025-03-27 08:13:20 +00:00
parent 626e5dd35c
commit 052a4acfc2
7 changed files with 183 additions and 126 deletions

View File

@@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig):
control_mode: str = "pd_ee_delta_pose"
state_dim: int = 25
action_dim: int = 7
fps: int = 400
fps: int = 200
episode_length: int = 50
obs_type: str = "rgb"
render_mode: str = "rgb_array"

View File

@@ -16,58 +16,100 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Optional
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType
@dataclass
class ConcurrencyConfig:
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: int = -5
log_std_max: int = 2
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Configuration class for Soft Actor-Critic (SAC) policy.
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy.
normalization_mapping: Mapping from feature types to normalization modes.
dataset_stats: Statistics for normalizing different data types.
camera_number: Number of cameras to use.
device: Device to use for training.
storage_device: Device to use for storage.
vision_encoder_name: Name of the vision encoder to use.
freeze_vision_encoder: Whether to freeze the vision encoder.
image_encoder_hidden_dim: Hidden dimension for the image encoder.
shared_encoder: Whether to use a shared encoder.
online_steps: Total number of online training steps.
actor_network: Configuration for the actor network architecture.
critic_network: Configuration for the critic network architecture.
policy: Configuration for the policy parameters.
n_obs_steps: Number of observation steps to consider.
normalization_mapping: Mapping of feature types to normalization modes.
dataset_stats: Statistics for normalizing different types of inputs.
input_features: Dictionary of input features with their types and shapes.
output_features: Dictionary of output features with their types and shapes.
camera_number: Number of cameras used for visual observations.
device: Device to run the model on (e.g., "cuda", "cpu").
storage_device: Device to store the model on.
vision_encoder_name: Name of the vision encoder model.
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
online_step_before_learning: Number of steps to collect before starting learning.
offline_buffer_capacity: Capacity of the offline replay buffer.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the RL algorithm.
temperature_init: Initial temperature for entropy regularization.
num_critics: Number of critic networks.
num_subsample_critics: Number of critics to subsample.
critic_lr: Learning rate for critic networks.
actor_lr: Learning rate for actor network.
temperature_lr: Learning rate for temperature parameter.
critic_target_update_weight: Weight for soft target updates.
utd_ratio: Update-to-data ratio (>1 to enable).
state_encoder_hidden_dim: Hidden dimension for state encoder.
latent_dim: Dimension of latent representation.
target_entropy: Target entropy for automatic temperature tuning.
use_backup_entropy: Whether to use backup entropy.
grad_clip_norm: Gradient clipping norm.
critic_network_kwargs: Additional arguments for critic networks.
actor_network_kwargs: Additional arguments for actor network.
policy_kwargs: Additional arguments for policy.
actor_learner_config: Configuration for actor-learner communication.
concurrency: Configuration for concurrency model.
discount: Discount factor for the SAC algorithm.
temperature_init: Initial temperature value.
num_critics: Number of critics in the ensemble.
num_subsample_critics: Number of subsampled critics for training.
critic_lr: Learning rate for the critic network.
actor_lr: Learning rate for the actor network.
temperature_lr: Learning rate for the temperature parameter.
critic_target_update_weight: Weight for the critic target update.
utd_ratio: Update-to-data ratio for the UTD algorithm.
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
latent_dim: Dimension of the latent space.
target_entropy: Target entropy for the SAC algorithm.
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
"""
# Input / output structure
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig):
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {
@@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig):
}
)
input_features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
output_features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
}
)
# Architecture specifics
camera_number: int = 1
device: str = "cuda"
@@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig):
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
online_step_before_learning: int = 100
policy_update_freq: int = 1
@@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig):
grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
critic_network_kwargs: CriticNetworkConfig = field(
default_factory=CriticNetworkConfig
)
actor_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs: ActorNetworkConfig = field(
default_factory=ActorNetworkConfig
)
policy_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.05,
}
policy_kwargs: PolicyConfig = field(
default_factory=PolicyConfig
)
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
"policy_parameters_push_frequency": 4,
}
actor_learner_config: ActorLearnerConfig = field(
default_factory=ActorLearnerConfig
)
concurrency: dict[str, str] = field(
default_factory=lambda: {
"actor": "threads",
"learner": "threads"
}
concurrency: ConcurrencyConfig = field(
default_factory=ConcurrencyConfig
)
def __post_init__(self):
@@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig):
return None
def validate_features(self) -> None:
# TODO: Maybe we should remove this raise?
if len(self.image_features) == 0:
raise ValueError("You must provide at least one image among the inputs.")
if "observation.image" not in self.input_features:
raise ValueError("You must provide 'observation.image' in the input features")
if "observation.state" not in self.input_features:
raise ValueError("You must provide 'observation.state' in the input features")
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features.keys() if 'image' in key]
@property
def observation_delta_indices(self) -> list:

View File

@@ -17,6 +17,7 @@
# TODO: (1) better device management
from dataclasses import asdict
import math
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -88,7 +89,7 @@ class SACPolicy(
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
@@ -103,7 +104,7 @@ class SACPolicy(
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
**asdict(config.critic_network_kwargs),
)
for _ in range(config.num_critics)
]
@@ -121,10 +122,10 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
action_dim=config.output_features["action"].shape[0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
**asdict(config.policy_kwargs),
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)