[WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
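For context, a minimal sketch of what the refactor means for callers, assuming the usual `lerobot` module layout (illustrative, not a snippet from this commit): the servers now read typed attributes instead of indexing into plain dicts.

```python
# Hypothetical usage after this change: configuration values are typed
# dataclass attributes rather than dict entries.
from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig()

# Before: host = config.actor_learner_config["learner_host"]
# After:
host = config.actor_learner_config.learner_host       # "127.0.0.1"
port = config.actor_learner_config.learner_port       # 50051
use_threads = config.concurrency.actor == "threads"   # True by default
```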
@@ -173,7 +173,7 @@ class ManiskillEnvConfig(EnvConfig):
     control_mode: str = "pd_ee_delta_pose"
     state_dim: int = 25
     action_dim: int = 7
-    fps: int = 400
+    fps: int = 200
    episode_length: int = 50
     obs_type: str = "rgb"
     render_mode: str = "rgb_array"
@@ -16,58 +16,100 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, Optional

 from lerobot.common.optim.optimizers import MultiAdamConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
+from lerobot.configs.types import NormalizationMode, PolicyFeature, FeatureType


+@dataclass
+class ConcurrencyConfig:
+    actor: str = "threads"
+    learner: str = "threads"
+
+
+@dataclass
+class ActorLearnerConfig:
+    learner_host: str = "127.0.0.1"
+    learner_port: int = 50051
+    policy_parameters_push_frequency: int = 4
+
+
+@dataclass
+class CriticNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+    final_activation: str | None = None
+
+
+@dataclass
+class ActorNetworkConfig:
+    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
+    activate_final: bool = True
+
+
+@dataclass
+class PolicyConfig:
+    use_tanh_squash: bool = True
+    log_std_min: int = -5
+    log_std_max: int = 2
+    init_final: float = 0.05
+
+
 @PreTrainedConfig.register_subclass("sac")
 @dataclass
 class SACConfig(PreTrainedConfig):
-    """Configuration class for Soft Actor-Critic (SAC) policy.
+    """Soft Actor-Critic (SAC) configuration.

     SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
     reinforcement learning framework. It learns a policy and a Q-function simultaneously
     using experience collected from the environment.

     This configuration class contains all the parameters needed to define a SAC agent,
     including network architectures, optimization settings, and algorithm-specific
     hyperparameters.

     Args:
-        n_obs_steps: Number of environment steps worth of observations to pass to the policy.
-        normalization_mapping: Mapping from feature types to normalization modes.
-        dataset_stats: Statistics for normalizing different data types.
-        camera_number: Number of cameras to use.
-        device: Device to use for training.
-        storage_device: Device to use for storage.
-        vision_encoder_name: Name of the vision encoder to use.
-        freeze_vision_encoder: Whether to freeze the vision encoder.
-        image_encoder_hidden_dim: Hidden dimension for the image encoder.
-        shared_encoder: Whether to use a shared encoder.
-        online_steps: Total number of online training steps.
-        actor_network: Configuration for the actor network architecture.
-        critic_network: Configuration for the critic network architecture.
-        policy: Configuration for the policy parameters.
+        n_obs_steps: Number of observation steps to consider.
+        normalization_mapping: Mapping of feature types to normalization modes.
+        dataset_stats: Statistics for normalizing different types of inputs.
+        input_features: Dictionary of input features with their types and shapes.
+        output_features: Dictionary of output features with their types and shapes.
+        camera_number: Number of cameras used for visual observations.
+        device: Device to run the model on (e.g., "cuda", "cpu").
+        storage_device: Device to store the model on.
+        vision_encoder_name: Name of the vision encoder model.
+        freeze_vision_encoder: Whether to freeze the vision encoder during training.
+        image_encoder_hidden_dim: Hidden dimension size for the image encoder.
+        shared_encoder: Whether to use a shared encoder for actor and critic.
+        concurrency: Configuration for concurrency settings.
+        actor_learner: Configuration for actor-learner architecture.
+        online_steps: Number of steps for online training.
+        online_env_seed: Seed for the online environment.
+        online_buffer_capacity: Capacity of the online replay buffer.
+        online_step_before_learning: Number of steps to collect before starting learning.
+        offline_buffer_capacity: Capacity of the offline replay buffer.
-        online_step_before_learning: Number of steps before learning starts.
-        policy_update_freq: Frequency of policy updates.
-        discount: Discount factor for the RL algorithm.
-        temperature_init: Initial temperature for entropy regularization.
-        num_critics: Number of critic networks.
-        num_subsample_critics: Number of critics to subsample.
-        critic_lr: Learning rate for critic networks.
-        actor_lr: Learning rate for actor network.
-        temperature_lr: Learning rate for temperature parameter.
-        critic_target_update_weight: Weight for soft target updates.
-        utd_ratio: Update-to-data ratio (>1 to enable).
-        state_encoder_hidden_dim: Hidden dimension for state encoder.
-        latent_dim: Dimension of latent representation.
-        target_entropy: Target entropy for automatic temperature tuning.
-        use_backup_entropy: Whether to use backup entropy.
-        grad_clip_norm: Gradient clipping norm.
-        critic_network_kwargs: Additional arguments for critic networks.
-        actor_network_kwargs: Additional arguments for actor network.
-        policy_kwargs: Additional arguments for policy.
-        actor_learner_config: Configuration for actor-learner communication.
-        concurrency: Configuration for concurrency model.
+        discount: Discount factor for the SAC algorithm.
+        temperature_init: Initial temperature value.
+        num_critics: Number of critics in the ensemble.
+        num_subsample_critics: Number of subsampled critics for training.
+        critic_lr: Learning rate for the critic network.
+        actor_lr: Learning rate for the actor network.
+        temperature_lr: Learning rate for the temperature parameter.
+        critic_target_update_weight: Weight for the critic target update.
+        utd_ratio: Update-to-data ratio for the UTD algorithm.
+        state_encoder_hidden_dim: Hidden dimension size for the state encoder.
+        latent_dim: Dimension of the latent space.
+        target_entropy: Target entropy for the SAC algorithm.
+        use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
+        grad_clip_norm: Gradient clipping norm for the SAC algorithm.
     """

     # Input / output structure
     n_obs_steps: int = 1

     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
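The hunk above introduces the nested config dataclasses. A sketch of how they compose (field names and defaults from this diff; the import path is an assumption):

```python
from lerobot.common.policies.sac.configuration_sac import (
    ActorLearnerConfig,
    CriticNetworkConfig,
    SACConfig,
)

# Override individual sub-configs; unspecified fields keep their defaults.
config = SACConfig(
    critic_network_kwargs=CriticNetworkConfig(hidden_dims=[512, 512]),
    actor_learner_config=ActorLearnerConfig(learner_host="0.0.0.0"),
)
assert config.critic_network_kwargs.activate_final  # default still True
```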
@@ -76,6 +118,7 @@ class SACConfig(PreTrainedConfig):
             "ACTION": NormalizationMode.MIN_MAX,
         }
     )
+
     dataset_stats: dict[str, dict[str, list[float]]] = field(
         default_factory=lambda: {
             "observation.image": {
@@ -93,6 +136,18 @@ class SACConfig(PreTrainedConfig):
         }
     )

+    input_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
+            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
+        }
+    )
+    output_features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(3,)),
+        }
+    )
+
     # Architecture specifics
     camera_number: int = 1
     device: str = "cuda"
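The default shapes above are placeholders. Here is a sketch of declaring features for an environment like the `ManiskillEnvConfig` in this commit (state_dim=25, action_dim=7); the image shape is carried over from the defaults above and may not match the actual render size:

```python
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.configs.types import FeatureType, PolicyFeature

config = SACConfig(
    input_features={
        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
    },
    output_features={
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    },
)
```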
@@ -106,7 +161,8 @@ class SACConfig(PreTrainedConfig):
     # Training parameter
     online_steps: int = 1000000
     online_env_seed: int = 10000
-    online_buffer_capacity: int = 10000
+    online_buffer_capacity: int = 100000
+    offline_buffer_capacity: int = 100000
     online_step_before_learning: int = 100
     policy_update_freq: int = 1
@@ -127,40 +183,21 @@ class SACConfig(PreTrainedConfig):
     grad_clip_norm: float = 40.0

     # Network configuration
-    critic_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-            "final_activation": None,
-        }
+    critic_network_kwargs: CriticNetworkConfig = field(
+        default_factory=CriticNetworkConfig
     )
-    actor_network_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "hidden_dims": [256, 256],
-            "activate_final": True,
-        }
+    actor_network_kwargs: ActorNetworkConfig = field(
+        default_factory=ActorNetworkConfig
     )
-    policy_kwargs: dict[str, Any] = field(
-        default_factory=lambda: {
-            "use_tanh_squash": True,
-            "log_std_min": -5,
-            "log_std_max": 2,
-            "init_final": 0.05,
-        }
+    policy_kwargs: PolicyConfig = field(
+        default_factory=PolicyConfig
     )

-    actor_learner_config: dict[str, str | int] = field(
-        default_factory=lambda: {
-            "learner_host": "127.0.0.1",
-            "learner_port": 50051,
-            "policy_parameters_push_frequency": 4,
-        }
+    actor_learner_config: ActorLearnerConfig = field(
+        default_factory=ActorLearnerConfig
     )
-    concurrency: dict[str, str] = field(
-        default_factory=lambda: {
-            "actor": "threads",
-            "learner": "threads"
-        }
+    concurrency: ConcurrencyConfig = field(
+        default_factory=ConcurrencyConfig
     )

     def __post_init__(self):
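The dict-valued fields become dataclasses here, and the modeling hunks below convert them back with `dataclasses.asdict` at the call sites. A sketch of the equivalence (import path assumed):

```python
from dataclasses import asdict

from lerobot.common.policies.sac.configuration_sac import CriticNetworkConfig

# asdict() reproduces the old dict-based critic_network_kwargs exactly,
# so CriticHead(input_dim=..., **asdict(cfg)) sees the same keyword arguments.
assert asdict(CriticNetworkConfig()) == {
    "hidden_dims": [256, 256],
    "activate_final": True,
    "final_activation": None,
}
```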
@@ -181,9 +218,18 @@ class SACConfig(PreTrainedConfig):
         return None

     def validate_features(self) -> None:
-        # TODO: Maybe we should remove this raise?
-        if len(self.image_features) == 0:
-            raise ValueError("You must provide at least one image among the inputs.")
+        if "observation.image" not in self.input_features:
+            raise ValueError("You must provide 'observation.image' in the input features")
+
+        if "observation.state" not in self.input_features:
+            raise ValueError("You must provide 'observation.state' in the input features")
+
+        if "action" not in self.output_features:
+            raise ValueError("You must provide 'action' in the output features")
+
+    @property
+    def image_features(self) -> list[str]:
+        return [key for key in self.input_features.keys() if 'image' in key]

     @property
     def observation_delta_indices(self) -> list:
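A sketch of the stricter validation in action (default construction assumed):

```python
from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig()
config.input_features.pop("observation.state")
try:
    config.validate_features()
except ValueError as err:
    print(err)  # You must provide 'observation.state' in the input features
```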
@@ -17,6 +17,7 @@

 # TODO: (1) better device management

+from dataclasses import asdict
 import math
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -88,7 +89,7 @@ class SACPolicy(
         critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -103,7 +104,7 @@ class SACPolicy(
         target_critic_heads = [
             CriticHead(
                 input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
-                **config.critic_network_kwargs,
+                **asdict(config.critic_network_kwargs),
             )
             for _ in range(config.num_critics)
         ]
@@ -121,10 +122,10 @@ class SACPolicy(

         self.actor = Policy(
             encoder=encoder_actor,
-            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
+            network=MLP(input_dim=encoder_actor.output_dim, **asdict(config.actor_network_kwargs)),
             action_dim=config.output_features["action"].shape[0],
             encoder_is_shared=config.shared_encoder,
-            **config.policy_kwargs,
+            **asdict(config.policy_kwargs),
         )
         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
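Worked example of the target-entropy fallback above: `shape[0]` is a scalar, so `np.prod` over it is a no-op and the expression reduces to -dim(A)/2. For the 7-dimensional action space in `ManiskillEnvConfig`:

```python
import numpy as np

action_dim = 7  # config.output_features["action"].shape[0]
target_entropy = -np.prod(action_dim) / 2  # -3.5, i.e. -dim(A)/2
```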