diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index e6f91ce8..41ae1247 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -68,6 +68,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
         [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
     )
     # TODO: add observation processor wrapper and remove preprocess_observation in the codebase
+    # https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19
     # env = ObservationProcessorWrapper(env=env)
 
     return env
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 6003acc0..66d6e5f9 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -81,35 +81,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     return return_observations
 
 
-class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
-    def __init__(self, env: gym.vector.VectorEnv):
-        super().__init__(env)
-
-    def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
-        return preprocess_observation(observations)
-
-    def reset(
-        self,
-        *,
-        seed: int | list[int] | None = None,
-        options: dict[str, Any] | None = None,
-    ):
-        """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
-        observations, infos = self.env.reset(seed=seed, options=options)
-        return self._observations(observations), infos
-
-    def step(self, actions):
-        """Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
-        observations, rewards, terminations, truncations, infos = self.env.step(actions)
-        return (
-            self._observations(observations),
-            rewards,
-            terminations,
-            truncations,
-            infos,
-        )
-
-
 def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
     # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
     # (need to also refactor preprocess_observation and externalize normalization from policies)
diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py
index 8a5b1803..903434f5 100644
--- a/lerobot/common/optim/optimizers.py
+++ b/lerobot/common/optim/optimizers.py
@@ -46,6 +46,15 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
 
     @abc.abstractmethod
     def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
+        """
+        Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
+        NOTE: Multiple optimizers are useful when you have different models to optimize, e.g.
+        one optimizer for the policy and another for the value function in reinforcement
+        learning settings.
+
+        Returns:
+            The optimizer or a dictionary of optimizers.
+ """ raise NotImplementedError diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 3e0d4dfb..9734bcab 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -79,48 +79,28 @@ def create_stats_buffers( ) # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) - if stats and key in stats: - # NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats - # Therefore, we don't access to full stats of the data, some elements either have min-max or mean-std only - if norm_mode is NormalizationMode.MEAN_STD: - if "mean" not in stats[key] or "std" not in stats[key]: - raise ValueError( - f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization" - ) - - if isinstance(stats[key]["mean"], np.ndarray): + if stats: + if isinstance(stats[key]["mean"], np.ndarray): + if norm_mode is NormalizationMode.MEAN_STD: buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) - elif isinstance(stats[key]["mean"], torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["mean"]) - raise ValueError( - f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead." - ) - - elif norm_mode is NormalizationMode.MIN_MAX: - if "min" not in stats[key] or "max" not in stats[key]: - raise ValueError( - f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization" - ) - - if isinstance(stats[key]["min"], np.ndarray): + elif norm_mode is NormalizationMode.MIN_MAX: buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) - elif isinstance(stats[key]["min"], torch.Tensor): + elif isinstance(stats[key]["mean"], torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) + buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["min"]) - raise ValueError( - f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead." 
- ) + else: + type_ = type(stats[key]["mean"]) + raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") stats_buffers[key] = buffer return stats_buffers @@ -169,13 +149,12 @@ class Normalize(nn.Module): setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - # @torch.no_grad + @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: # FIXME(aliberts, rcadene): This might lead to silent fail! - # NOTE: (azouitine) This continues help us for instantiation SACPolicy continue norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) @@ -244,7 +223,7 @@ class Unnormalize(nn.Module): setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - # @torch.no_grad + @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): @@ -273,3 +252,170 @@ class Unnormalize(nn.Module): else: raise ValueError(norm_mode) return batch + + +# TODO: We should replace all normalization on the policies with register_buffer normalization +# and remove the `Normalize` and `Unnormalize` classes. +def _initialize_stats_buffers( + module: nn.Module, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> None: + """Register statistics buffers (mean/std or min/max) on the given *module*. + + The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, + but is factored out so it can be reused by both classes and stay in sync. + """ + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + shape: tuple[int, ...] 
= tuple(ft.shape) + if ft.type is FeatureType.VISUAL: + # reduce spatial dimensions, keep channel dimension only + c, *_ = shape + shape = (c, 1, 1) + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = torch.full(shape, torch.inf, dtype=torch.float32) + std = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: + mean_data = stats[key]["mean"] + std_data = stats[key]["std"] + if isinstance(mean_data, np.ndarray): + mean = torch.from_numpy(mean_data).to(dtype=torch.float32) + std = torch.from_numpy(std_data).to(dtype=torch.float32) + elif isinstance(mean_data, torch.Tensor): + mean = mean_data.clone().to(dtype=torch.float32) + std = std_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_mean", mean) + module.register_buffer(f"{prefix}_std", std) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = torch.full(shape, torch.inf, dtype=torch.float32) + max_val = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "min" in stats[key] and "max" in stats[key]: + min_data = stats[key]["min"] + max_data = stats[key]["max"] + if isinstance(min_data, np.ndarray): + min_val = torch.from_numpy(min_data).to(dtype=torch.float32) + max_val = torch.from_numpy(max_data).to(dtype=torch.float32) + elif isinstance(min_data, torch.Tensor): + min_val = min_data.clone().to(dtype=torch.float32) + max_val = max_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_min", min_val) + module.register_buffer(f"{prefix}_max", max_val) + continue + + raise ValueError(norm_mode) + + +class NormalizeBuffer(nn.Module): + """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) + batch[key] = batch[key] * 2 - 1 + continue + + raise ValueError(norm_mode) + + return batch + + +class UnnormalizeBuffer(nn.Module): + """Inverse operation of `NormalizeBuffer`. 
Uses registered buffers for statistics.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = batch[key] * std + mean + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max_val - min_val) + min_val + continue + + raise ValueError(norm_mode) + + return batch diff --git a/lerobot/common/policies/reward_model/modeling_classifier.py b/lerobot/common/policies/reward_model/modeling_classifier.py index 476185db..4d665d12 100644 --- a/lerobot/common/policies/reward_model/modeling_classifier.py +++ b/lerobot/common/policies/reward_model/modeling_classifier.py @@ -9,9 +9,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - class ClassifierOutput: """Wrapper for classifier outputs with additional metadata.""" diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index e851afdc..1f2e9bb8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -24,6 +24,12 @@ from lerobot.configs.types import NormalizationMode @dataclass class ConcurrencyConfig: + """Configuration for the concurrency of the actor and learner. + Possible values are: + - "threads": Use threads for the actor and learner. + - "processes": Use processes for the actor and learner. + """ + actor: str = "threads" learner: str = "threads" @@ -68,51 +74,9 @@ class SACConfig(PreTrainedConfig): This configuration class contains all the parameters needed to define a SAC agent, including network architectures, optimization settings, and algorithm-specific hyperparameters. - - Args: - actor_network_kwargs: Configuration for the actor network architecture. - critic_network_kwargs: Configuration for the critic network architecture. - discrete_critic_network_kwargs: Configuration for the discrete critic network. - policy_kwargs: Configuration for the policy parameters. - n_obs_steps: Number of observation steps to consider. - normalization_mapping: Mapping of feature types to normalization modes. - dataset_stats: Statistics for normalizing different types of inputs. 
- input_features: Dictionary of input features with their types and shapes. - output_features: Dictionary of output features with their types and shapes. - camera_number: Number of cameras used for visual observations. - device: Device to run the model on (e.g., "cuda", "cpu"). - storage_device: Device to store the model on. - vision_encoder_name: Name of the vision encoder model. - freeze_vision_encoder: Whether to freeze the vision encoder during training. - image_encoder_hidden_dim: Hidden dimension size for the image encoder. - shared_encoder: Whether to use a shared encoder for actor and critic. - num_discrete_actions: Number of discrete actions, eg for gripper actions. - image_embedding_pooling_dim: Dimension of the image embedding pooling. - concurrency: Configuration for concurrency settings. - actor_learner_config: Configuration for actor-learner architecture. - online_steps: Number of steps for online training. - online_env_seed: Seed for the online environment. - online_buffer_capacity: Capacity of the online replay buffer. - offline_buffer_capacity: Capacity of the offline replay buffer. - async_prefetch: Whether to use asynchronous prefetching for the buffers. - online_step_before_learning: Number of steps before learning starts. - policy_update_freq: Frequency of policy updates. - discount: Discount factor for the SAC algorithm. - temperature_init: Initial temperature value. - num_critics: Number of critics in the ensemble. - num_subsample_critics: Number of subsampled critics for training. - critic_lr: Learning rate for the critic network. - actor_lr: Learning rate for the actor network. - temperature_lr: Learning rate for the temperature parameter. - critic_target_update_weight: Weight for the critic target update. - utd_ratio: Update-to-data ratio for the UTD algorithm. - state_encoder_hidden_dim: Hidden dimension size for the state encoder. - latent_dim: Dimension of the latent space. - target_entropy: Target entropy for the SAC algorithm. - use_backup_entropy: Whether to use backup entropy for the SAC algorithm. - grad_clip_norm: Gradient clipping norm for the SAC algorithm. 
""" + # Mapping of feature types to normalization modes normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.MEAN_STD, @@ -122,6 +86,7 @@ class SACConfig(PreTrainedConfig): } ) + # Statistics for normalizing different types of inputs dataset_stats: dict[str, dict[str, list[float]]] | None = field( default_factory=lambda: { "observation.image": { @@ -140,47 +105,81 @@ class SACConfig(PreTrainedConfig): ) # Architecture specifics + # Device to run the model on (e.g., "cuda", "cpu") device: str = "cpu" + # Device to store the model on storage_device: str = "cpu" - # Set to "helper2424/resnet10" for hil serl + # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) vision_encoder_name: str | None = None + # Whether to freeze the vision encoder during training freeze_vision_encoder: bool = True + # Hidden dimension size for the image encoder image_encoder_hidden_dim: int = 32 + # Whether to use a shared encoder for actor and critic shared_encoder: bool = True + # Number of discrete actions, eg for gripper actions num_discrete_actions: int | None = None + # Dimension of the image embedding pooling image_embedding_pooling_dim: int = 8 # Training parameter + # Number of steps for online training online_steps: int = 1000000 + # Seed for the online environment online_env_seed: int = 10000 + # Capacity of the online replay buffer online_buffer_capacity: int = 100000 + # Capacity of the offline replay buffer offline_buffer_capacity: int = 100000 + # Whether to use asynchronous prefetching for the buffers async_prefetch: bool = False + # Number of steps before learning starts online_step_before_learning: int = 100 + # Frequency of policy updates policy_update_freq: int = 1 # SAC algorithm parameters + # Discount factor for the SAC algorithm discount: float = 0.99 + # Initial temperature value temperature_init: float = 1.0 + # Number of critics in the ensemble num_critics: int = 2 + # Number of subsampled critics for training num_subsample_critics: int | None = None + # Learning rate for the critic network critic_lr: float = 3e-4 + # Learning rate for the actor network actor_lr: float = 3e-4 + # Learning rate for the temperature parameter temperature_lr: float = 3e-4 + # Weight for the critic target update critic_target_update_weight: float = 0.005 - utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1 + # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) + utd_ratio: int = 1 + # Hidden dimension size for the state encoder state_encoder_hidden_dim: int = 256 + # Dimension of the latent space latent_dim: int = 256 + # Target entropy for the SAC algorithm target_entropy: float | None = None + # Whether to use backup entropy for the SAC algorithm use_backup_entropy: bool = True + # Gradient clipping norm for the SAC algorithm grad_clip_norm: float = 40.0 # Network configuration + # Configuration for the critic network architecture critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for the actor network architecture actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Configuration for the policy parameters policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Configuration for the discrete critic network discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for actor-learner 
architecture actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) + # Configuration for concurrency settings (you can use threads or processes for the actor and learner) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) # Optimizations diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index 04145e12..257f37cb 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -27,7 +27,7 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution -from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.utils import get_device_from_parameters @@ -394,17 +394,16 @@ class SACPolicy( self.normalize_inputs = nn.Identity() self.normalize_targets = nn.Identity() self.unnormalize_outputs = nn.Identity() - - if self.config.dataset_stats: + if self.config.dataset_stats is not None: params = _convert_normalization_params_to_tensor(self.config.dataset_stats) - self.normalize_inputs = Normalize( + self.normalize_inputs = NormalizeBuffer( self.config.input_features, self.config.normalization_mapping, params ) stats = dataset_stats or params - self.normalize_targets = Normalize( + self.normalize_targets = NormalizeBuffer( self.config.output_features, self.config.normalization_mapping, stats ) - self.unnormalize_outputs = Unnormalize( + self.unnormalize_outputs = UnnormalizeBuffer( self.config.output_features, self.config.normalization_mapping, stats ) @@ -506,7 +505,7 @@ class SACObservationEncoder(nn.Module): if not self.has_images: return - if self.config.vision_encoder_name: + if self.config.vision_encoder_name is not None: self.image_encoder = PretrainedImageEncoder(self.config) else: self.image_encoder = DefaultImageEncoder(self.config) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index ebdc29e9..930b236f 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -19,9 +19,10 @@ import os.path as osp import platform import subprocess import time -from copy import copy +from copy import copy, deepcopy from datetime import datetime, timezone from pathlib import Path +from statistics import mean import numpy as np import torch @@ -108,11 +109,14 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(log_file=None): +def init_logging(log_file: Path | None = None, display_pid: bool = False): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}" + + # NOTE: Display PID is useful for multi-process logging. + pid_str = f"[PID: {os.getpid()}]" if display_pid else "" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -238,30 +242,99 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool: class TimerManager: + """ + Lightweight utility to measure elapsed time. 
+ + Examples + -------- + >>> timer = TimerManager("Policy", log=False) + >>> for _ in range(3): + ... with timer: + ... time.sleep(0.01) + >>> print(timer.last, timer.fps_avg, timer.percentile(90)) + """ + def __init__( self, - elapsed_time_list: list[float] | None = None, - label="Elapsed time", - log=True, + label: str = "Elapsed-time", + log: bool = True, + logger: logging.Logger | None = None, ): self.label = label - self.elapsed_time_list = elapsed_time_list self.log = log - self.elapsed = 0.0 + self.logger = logger + self._start: float | None = None + self._history: list[float] = [] def __enter__(self): - self.start = time.perf_counter() + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + self._start = time.perf_counter() return self - def __exit__(self, exc_type, exc_value, traceback): - self.elapsed: float = time.perf_counter() - self.start - - if self.elapsed_time_list is not None: - self.elapsed_time_list.append(self.elapsed) - + def stop(self) -> float: + if self._start is None: + raise RuntimeError("Timer was never started.") + elapsed = time.perf_counter() - self._start + self._history.append(elapsed) + self._start = None if self.log: - print(f"{self.label}: {self.elapsed:.6f} seconds") + if self.logger is not None: + self.logger.info(f"{self.label}: {elapsed:.6f} s") + else: + logging.info(f"{self.label}: {elapsed:.6f} s") + return elapsed + + def reset(self): + self._history.clear() @property - def elapsed_seconds(self): - return self.elapsed + def last(self) -> float: + return self._history[-1] if self._history else 0.0 + + @property + def avg(self) -> float: + return mean(self._history) if self._history else 0.0 + + @property + def total(self) -> float: + return sum(self._history) + + @property + def count(self) -> int: + return len(self._history) + + @property + def history(self) -> list[float]: + return deepcopy(self._history) + + @property + def fps_history(self) -> list[float]: + return [1.0 / t for t in self._history] + + @property + def fps_last(self) -> float: + return 0.0 if self.last == 0 else 1.0 / self.last + + @property + def fps_avg(self) -> float: + return 0.0 if self.avg == 0 else 1.0 / self.avg + + def percentile(self, p: float) -> float: + """ + Return the p-th percentile of recorded times. + """ + if not self._history: + return 0.0 + return float(np.percentile(self._history, p)) + + def fps_percentile(self, p: float) -> float: + """ + FPS corresponding to the p-th percentile time. + """ + val = self.percentile(p) + return 0.0 if val == 0 else 1.0 / val diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index db8911d5..b24099f3 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -123,9 +123,9 @@ class WandBLogger: if step is None and custom_step_key is None: raise ValueError("Either step or custom_step_key must be provided.") - # NOTE: This is not simple. Wandb step is it must always monotonically increase and it + # NOTE: This is not simple. Wandb step must always monotonically increase and it # increases with each wandb.log call, but in the case of asynchronous RL for example, - # multiple time steps is possible for example, the interaction step with the environment, + # multiple time steps is possible. For example, the interaction step with the environment, # the training step, the evaluation step, etc. So we need to define a custom step key # to log the correct step for each metric. 
if custom_step_key is not None: diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py index b2651573..5f7d5648 100644 --- a/lerobot/scripts/server/actor_server.py +++ b/lerobot/scripts/server/actor_server.py @@ -13,13 +13,67 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Actor server runner for distributed HILSerl robot policy training. + +This script implements the actor component of the distributed HILSerl architecture. +It executes the policy in the robot environment, collects experience, +and sends transitions to the learner server for policy updates. + +Examples of usage: + +- Start an actor server for real robot training with human-in-the-loop intervention: +```bash +python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +- Run with a specific robot type for a pick and place task: +```bash +python lerobot/scripts/server/actor_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --robot.type=so100 \ + --task=pick_and_place +``` + +- Set a custom workspace bound for the robot's end-effector: +```bash +python lerobot/scripts/server/actor_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \ + --env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]" +``` + +- Run with specific camera crop parameters: +```bash +python lerobot/scripts/server/actor_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}" +``` + +**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner +server is started before launching the actor. + +**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the +gamepad to take control of the robot during training. Initially intervene frequently, then gradually +reduce interventions as the policy improves. + +**WORKFLOW**: +1. Determine robot workspace bounds using `find_joint_limits.py` +2. Record demonstrations with `gym_manipulator.py` in record mode +3. Process the dataset and determine camera crops with `crop_dataset_roi.py` +4. Start the learner server with the training configuration +5. Start this actor server with the same configuration +6. 
Use human interventions to guide policy learning + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" import logging import os import time from functools import lru_cache from queue import Empty -from statistics import mean, quantiles import grpc import torch @@ -65,10 +119,12 @@ ACTOR_SHUTDOWN_TIMEOUT = 30 @parser.wrap() def actor_cli(cfg: TrainPipelineConfig): cfg.validate() + display_pid = False if not use_threads(cfg): import torch.multiprocessing as mp mp.set_start_method("spawn") + display_pid = True # Create logs directory to ensure it exists log_dir = os.path.join(cfg.output_dir, "logs") @@ -76,7 +132,7 @@ def actor_cli(cfg: TrainPipelineConfig): log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=display_pid) logging.info(f"Actor logging initialized, writing to {log_file}") shutdown_event = setup_process_handlers(use_threads(cfg)) @@ -193,7 +249,7 @@ def act_with_policy( log_dir = os.path.join(cfg.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log") - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Actor policy process logging initialized") logging.info("make_env online") @@ -223,12 +279,13 @@ def act_with_policy( # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 list_transition_to_send_to_learner = [] - list_policy_time = [] episode_intervention = False # Add counters for intervention rate calculation episode_intervention_steps = 0 episode_total_steps = 0 + policy_timer = TimerManager("Policy inference", log=False) + for interaction_step in range(cfg.policy.online_steps): start_time = time.perf_counter() if shutdown_event.is_set(): @@ -237,13 +294,9 @@ def act_with_policy( if interaction_step >= cfg.policy.online_step_before_learning: # Time policy inference and check if it meets FPS requirement - with TimerManager( - elapsed_time_list=list_policy_time, - label="Policy inference time", - log=False, - ) as timer: # noqa: F841 + with policy_timer: action = policy.select_action(batch=obs) - policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) + policy_fps = policy_timer.fps_last log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) @@ -291,8 +344,8 @@ def act_with_policy( ) list_transition_to_send_to_learner = [] - stats = get_frequency_stats(list_policy_time) - list_policy_time.clear() + stats = get_frequency_stats(policy_timer) + policy_timer.reset() # Calculate intervention rate intervention_rate = 0.0 @@ -429,7 +482,7 @@ def receive_policy( log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Actor receive policy process logging initialized") # Setup process handlers to handle shutdown signal @@ -484,7 +537,7 @@ def send_transitions( log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Actor transitions process logging initialized") # Setup process handlers to handle shutdown signal @@ -533,7 +586,7 @@ def send_interactions( log_file = 
os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Actor interactions process logging initialized") # Setup process handlers to handle shutdown signal @@ -632,25 +685,24 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue): transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) -def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: +def get_frequency_stats(timer: TimerManager) -> dict[str, float]: """Get the frequency statistics of the policy. Args: - list_policy_time (list[float]): The list of policy times. + timer (TimerManager): The timer with collected metrics. Returns: dict[str, float]: The frequency statistics of the policy. """ stats = {} - list_policy_fps = [1.0 / t for t in list_policy_time] - if len(list_policy_fps) > 1: - policy_fps = mean(list_policy_fps) - quantiles_90 = quantiles(list_policy_fps, n=10)[-1] - logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") - logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") + if timer.count > 1: + avg_fps = timer.fps_avg + p90_fps = timer.fps_percentile(90) + logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}") stats = { - "Policy frequency [Hz]": policy_fps, - "Policy frequency 90th-p [Hz]": quantiles_90, + "Policy frequency [Hz]": avg_fps, + "Policy frequency 90th-p [Hz]": p90_fps, } return stats diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index 45881de8..22d64094 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -203,6 +203,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( if key in new_dataset.meta.info["features"]: new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) + # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset prev_episode_index = 0 for frame_idx in tqdm(range(len(original_dataset))): frame = original_dataset[frame_idx] diff --git a/lerobot/scripts/server/end_effector_control_utils.py b/lerobot/scripts/server/end_effector_control_utils.py index 1b8613f4..e04f3d87 100644 --- a/lerobot/scripts/server/end_effector_control_utils.py +++ b/lerobot/scripts/server/end_effector_control_utils.py @@ -23,10 +23,9 @@ import numpy as np import torch from lerobot.common.robot_devices.utils import busy_wait +from lerobot.common.utils.utils import init_logging from lerobot.scripts.server.kinematics import RobotKinematics -logging.basicConfig(level=logging.INFO) - class InputController: """Base class for input controllers that generate motion deltas.""" @@ -726,6 +725,8 @@ if __name__ == "__main__": from lerobot.common.robot_devices.robots.utils import make_robot_from_config from lerobot.scripts.server.gym_manipulator import make_robot_env + init_logging() + parser = argparse.ArgumentParser(description="Test end-effector control") parser.add_argument( "--mode", diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py index ed2f3a45..7ec22e31 100644 --- a/lerobot/scripts/server/gym_manipulator.py +++ b/lerobot/scripts/server/gym_manipulator.py @@ -1588,19 +1588,20 @@ class GamepadControlWrapper(gym.Wrapper): input_threshold: Minimum movement delta to consider as active input. 
""" super().__init__(env) - from lerobot.scripts.server.end_effector_control_utils import ( - GamepadController, - GamepadControllerHID, - ) # use HidApi for macos if sys.platform == "darwin": + # NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi + from lerobot.scripts.server.end_effector_control_utils import GamepadControllerHID + self.controller = GamepadControllerHID( x_step_size=x_step_size, y_step_size=y_step_size, z_step_size=z_step_size, ) else: + from lerobot.scripts.server.end_effector_control_utils import GamepadController + self.controller = GamepadController( x_step_size=x_step_size, y_step_size=y_step_size, @@ -1748,6 +1749,8 @@ class GymHilDeviceWrapper(gym.Wrapper): for k in obs: obs[k] = obs[k].to(self.device) if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) return obs, reward, terminated, truncated, info @@ -1756,6 +1759,8 @@ class GymHilDeviceWrapper(gym.Wrapper): for k in obs: obs[k] = obs[k].to(self.device) if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) return obs, info diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py index eae1e925..7fbf4621 100644 --- a/lerobot/scripts/server/learner_server.py +++ b/lerobot/scripts/server/learner_server.py @@ -14,6 +14,66 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Learner server runner for distributed HILSerl robot policy training. + +This script implements the learner component of the distributed HILSerl architecture. +It initializes the policy network, maintains replay buffers, and updates +the policy based on transitions received from the actor server. 
+ +Examples of usage: + +- Start a learner server for training: +```bash +python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +- Run with specific SAC hyperparameters: +```bash +python lerobot/scripts/server/learner_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --learner.sac.alpha=0.1 \ + --learner.sac.gamma=0.99 +``` + +- Run with a specific dataset and wandb logging: +```bash +python lerobot/scripts/server/learner_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --dataset.repo_id=username/pick_lift_cube \ + --wandb.enable=true \ + --wandb.project=hilserl_training +``` + +- Run with a pretrained policy for fine-tuning: +```bash +python lerobot/scripts/server/learner_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model +``` + +- Run with a reward classifier model: +```bash +python lerobot/scripts/server/learner_server.py \ + --config_path lerobot/configs/train_config_hilserl_so100.json \ + --reward_classifier_pretrained_path=outputs/reward_model/best_model +``` + +**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server +to communicate with actors. + +**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true +in your configuration. + +**WORKFLOW**: +1. Create training configuration with proper policy, dataset, and environment settings +2. Start this learner server with the configuration +3. Start an actor server with the same configuration +4. Monitor training progress through wandb dashboard + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" import logging import os @@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import ( LOG_PREFIX = "[LEARNER]" -logging.basicConfig(level=logging.INFO) ################################################# # MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # @@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None): if job_name is None: raise ValueError("Job name must be specified either in config or as a parameter") + display_pid = False + if not use_threads(cfg): + display_pid = True + # Create logs directory to ensure it exists log_dir = os.path.join(cfg.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"learner_{job_name}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=display_pid) logging.info(f"Learner logging initialized, writing to {log_file}") logging.info(pformat(cfg.to_dict())) @@ -275,7 +338,7 @@ def add_actor_information_and_train( log_dir = os.path.join(cfg.output_dir, "logs") os.makedirs(log_dir, exist_ok=True) log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log") - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Initialized logging for actor information and training process") logging.info("Initializing policy") @@ -604,7 +667,7 @@ def start_learner_server( log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log") # Initialize logging with explicit log file - init_logging(log_file=log_file) + init_logging(log_file=log_file, display_pid=True) logging.info("Learner server process logging 
initialized") # Setup process handlers to handle shutdown signal diff --git a/pyproject.toml b/pyproject.toml index d8262f17..47c46a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ dora = [ ] dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] -hilserl = ["transformers>=4.48.0", "gym-hil>=0.1.2", "protobuf>=5.29.3", "grpcio>=1.70.0"] +hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] pi0 = ["transformers>=4.48.0"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index 98333e9f..c7733785 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -59,7 +59,6 @@ def test_sac_config_default_initialization(): assert config.num_critics == 2 # Architecture specifics - assert config.camera_number == 1 assert config.vision_encoder_name is None assert config.freeze_vision_encoder is True assert config.image_encoder_hidden_dim == 32 diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 18e3b6f2..a97d5014 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -9,6 +9,13 @@ from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.types import FeatureType, PolicyFeature +try: + import transformers # noqa: F401 + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + def test_mlp_with_default_args(): mlp = MLP(input_dim=10, hidden_dims=[256, 256]) @@ -274,6 +281,7 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di "batch_size,state_dim,action_dim,vision_encoder_name", [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], ) +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") def test_sac_policy_with_pretrained_encoder( batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str ):