[HIL-SERL] Review feedback modifications (#1112)

This commit is contained in:
Adil Zouitine
2025-05-15 15:24:41 +02:00
committed by GitHub
parent c7a3973653
commit 2051dd38fc
17 changed files with 504 additions and 180 deletions

View File

@@ -68,6 +68,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
) )
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase # TODO: add observation processor wrapper and remove preprocess_observation in the codebase
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
# env = ObservationProcessorWrapper(env=env) # env = ObservationProcessorWrapper(env=env)
return env return env

View File

@@ -81,35 +81,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return return_observations return return_observations
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
    """Vector-env wrapper that runs every raw observation through ``preprocess_observation``."""

    def __init__(self, env: gym.vector.VectorEnv):
        super().__init__(env)

    def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
        # Thin adapter: the actual conversion lives in the shared helper.
        return preprocess_observation(observations)

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ):
        """Reset the wrapped env; return processed observations together with infos."""
        raw_obs, infos = self.env.reset(seed=seed, options=options)
        return self._observations(raw_obs), infos

    def step(self, actions):
        """Step the wrapped env; return the usual 5-tuple with processed observations."""
        raw_obs, rewards, terminations, truncations, infos = self.env.step(actions)
        processed = self._observations(raw_obs)
        return processed, rewards, terminations, truncations, infos
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies) # (need to also refactor preprocess_observation and externalize normalization from policies)

View File

@@ -46,6 +46,15 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
NOTE: Multiple optimizers are useful when you have different models to optimize.
For example, you can have one optimizer for the policy and another one for the value function
in reinforcement learning settings.
Returns:
The optimizer or a dictionary of optimizers.
"""
raise NotImplementedError raise NotImplementedError

View File

@@ -79,48 +79,28 @@ def create_stats_buffers(
) )
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats and key in stats: if stats:
# NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats if isinstance(stats[key]["mean"], np.ndarray):
# Therefore, we don't access to full stats of the data, some elements either have min-max or mean-std only if norm_mode is NormalizationMode.MEAN_STD:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
)
if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor): elif norm_mode is NormalizationMode.MIN_MAX:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
)
if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["min"], torch.Tensor): elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else: else:
type_ = type(stats[key]["min"]) type_ = type(stats[key]["mean"])
raise ValueError( raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
)
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers
@@ -169,13 +149,12 @@ class Normalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad? # TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items(): for key, ft in self.features.items():
if key not in batch: if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail! # FIXME(aliberts, rcadene): This might lead to silent fail!
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
continue continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
@@ -244,7 +223,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad? # TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items(): for key, ft in self.features.items():
@@ -273,3 +252,170 @@ class Unnormalize(nn.Module):
else: else:
raise ValueError(norm_mode) raise ValueError(norm_mode)
return batch return batch
# TODO: We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
    module: nn.Module,
    features: dict[str, PolicyFeature],
    norm_map: dict[str, NormalizationMode],
    stats: dict[str, dict[str, Tensor]] | None = None,
) -> None:
    """Register statistics buffers (mean/std or min/max) on the given *module*.

    The logic matches the previous constructors of `NormalizeBuffer` and
    `UnnormalizeBuffer`, but is factored out so it can be reused by both classes
    and stay in sync. Features whose normalization mode is IDENTITY get no
    buffers. When statistics are missing for a feature, its buffers are filled
    with ``inf`` so downstream code can detect (and assert on) unset stats.

    Args:
        module: Module on which the buffers are registered.
        features: Mapping from feature key to its `PolicyFeature` (type and shape).
        norm_map: Mapping from feature type to the normalization mode to apply.
        stats: Optional nested mapping ``{key: {"mean"/"std"/"min"/"max": value}}``
            whose leaves are ``np.ndarray`` or ``torch.Tensor``.

    Raises:
        ValueError: If a stats value is neither ``np.ndarray`` nor ``torch.Tensor``,
            or if an unknown normalization mode is encountered.
    """
    for key, ft in features.items():
        norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
        if norm_mode is NormalizationMode.IDENTITY:
            continue

        shape: tuple[int, ...] = tuple(ft.shape)
        if ft.type is FeatureType.VISUAL:
            # reduce spatial dimensions, keep channel dimension only
            c, *_ = shape
            shape = (c, 1, 1)

        if norm_mode is NormalizationMode.MEAN_STD:
            stat_names = ("mean", "std")
        elif norm_mode is NormalizationMode.MIN_MAX:
            stat_names = ("min", "max")
        else:
            raise ValueError(norm_mode)

        has_stats = stats is not None and key in stats and all(name in stats[key] for name in stat_names)
        prefix = key.replace(".", "_")
        for name in stat_names:
            if not has_stats:
                # Sentinel: inf marks "statistics not provided yet".
                value = torch.full(shape, torch.inf, dtype=torch.float32)
            else:
                raw = stats[key][name]
                # Check each value independently: mean/std (or min/max) are not
                # guaranteed to share a type, and a mismatch would otherwise
                # surface as an opaque TypeError.
                if isinstance(raw, np.ndarray):
                    value = torch.from_numpy(raw).to(dtype=torch.float32)
                elif isinstance(raw, torch.Tensor):
                    # Clone so save_pretrained never sees duplicated/shared tensors.
                    value = raw.clone().to(dtype=torch.float32)
                else:
                    raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
            module.register_buffer(f"{prefix}_{name}", value)
class NormalizeBuffer(nn.Module):
    """Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of *batch* with every known feature normalized."""
        out = dict(batch)  # shallow copy avoids mutating the caller's dict
        for key, ft in self.features.items():
            if key not in out:
                continue
            mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if mode is NormalizationMode.IDENTITY:
                continue
            name = key.replace(".", "_")
            if mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{name}_mean")
                std = getattr(self, f"{name}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                out[key] = (out[key] - mean) / (std + 1e-8)
            elif mode is NormalizationMode.MIN_MAX:
                low = getattr(self, f"{name}_min")
                high = getattr(self, f"{name}_max")
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Rescale to [0, 1], then shift to [-1, 1].
                out[key] = (out[key] - low) / (high - low + 1e-8) * 2 - 1
            else:
                raise ValueError(mode)
        return out
class UnnormalizeBuffer(nn.Module):
    """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()
        self.features = features
        self.norm_map = norm_map
        _initialize_stats_buffers(self, features, norm_map, stats)

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of *batch* with every known feature mapped back to raw units."""
        out = dict(batch)  # shallow copy avoids mutating the caller's dict
        for key, ft in self.features.items():
            if key not in out:
                continue
            mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if mode is NormalizationMode.IDENTITY:
                continue
            name = key.replace(".", "_")
            if mode is NormalizationMode.MEAN_STD:
                mean = getattr(self, f"{name}_mean")
                std = getattr(self, f"{name}_std")
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                out[key] = out[key] * std + mean
            elif mode is NormalizationMode.MIN_MAX:
                low = getattr(self, f"{name}_min")
                high = getattr(self, f"{name}_max")
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Shift from [-1, 1] back to [0, 1], then rescale to [low, high].
                out[key] = (out[key] + 1) / 2 * (high - low) + low
            else:
                raise ValueError(mode)
        return out

View File

@@ -9,9 +9,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class ClassifierOutput: class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata.""" """Wrapper for classifier outputs with additional metadata."""

View File

@@ -24,6 +24,12 @@ from lerobot.configs.types import NormalizationMode
@dataclass @dataclass
class ConcurrencyConfig: class ConcurrencyConfig:
"""Configuration for the concurrency of the actor and learner.
Possible values are:
- "threads": Use threads for the actor and learner.
- "processes": Use processes for the actor and learner.
"""
actor: str = "threads" actor: str = "threads"
learner: str = "threads" learner: str = "threads"
@@ -68,51 +74,9 @@ class SACConfig(PreTrainedConfig):
This configuration class contains all the parameters needed to define a SAC agent, This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific including network architectures, optimization settings, and algorithm-specific
hyperparameters. hyperparameters.
Args:
actor_network_kwargs: Configuration for the actor network architecture.
critic_network_kwargs: Configuration for the critic network architecture.
discrete_critic_network_kwargs: Configuration for the discrete critic network.
policy_kwargs: Configuration for the policy parameters.
n_obs_steps: Number of observation steps to consider.
normalization_mapping: Mapping of feature types to normalization modes.
dataset_stats: Statistics for normalizing different types of inputs.
input_features: Dictionary of input features with their types and shapes.
output_features: Dictionary of output features with their types and shapes.
camera_number: Number of cameras used for visual observations.
device: Device to run the model on (e.g., "cuda", "cpu").
storage_device: Device to store the model on.
vision_encoder_name: Name of the vision encoder model.
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, eg for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
actor_learner_config: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
temperature_init: Initial temperature value.
num_critics: Number of critics in the ensemble.
num_subsample_critics: Number of subsampled critics for training.
critic_lr: Learning rate for the critic network.
actor_lr: Learning rate for the actor network.
temperature_lr: Learning rate for the temperature parameter.
critic_target_update_weight: Weight for the critic target update.
utd_ratio: Update-to-data ratio for the UTD algorithm.
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
latent_dim: Dimension of the latent space.
target_entropy: Target entropy for the SAC algorithm.
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
""" """
# Mapping of feature types to normalization modes
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD, "VISUAL": NormalizationMode.MEAN_STD,
@@ -122,6 +86,7 @@ class SACConfig(PreTrainedConfig):
} }
) )
# Statistics for normalizing different types of inputs
dataset_stats: dict[str, dict[str, list[float]]] | None = field( dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: { default_factory=lambda: {
"observation.image": { "observation.image": {
@@ -140,47 +105,81 @@ class SACConfig(PreTrainedConfig):
) )
# Architecture specifics # Architecture specifics
# Device to run the model on (e.g., "cuda", "cpu")
device: str = "cpu" device: str = "cpu"
# Device to store the model on
storage_device: str = "cpu" storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
vision_encoder_name: str | None = None vision_encoder_name: str | None = None
# Whether to freeze the vision encoder during training
freeze_vision_encoder: bool = True freeze_vision_encoder: bool = True
# Hidden dimension size for the image encoder
image_encoder_hidden_dim: int = 32 image_encoder_hidden_dim: int = 32
# Whether to use a shared encoder for actor and critic
shared_encoder: bool = True shared_encoder: bool = True
# Number of discrete actions, eg for gripper actions
num_discrete_actions: int | None = None num_discrete_actions: int | None = None
# Dimension of the image embedding pooling
image_embedding_pooling_dim: int = 8 image_embedding_pooling_dim: int = 8
# Training parameter # Training parameter
# Number of steps for online training
online_steps: int = 1000000 online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000 online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000 online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
offline_buffer_capacity: int = 100000 offline_buffer_capacity: int = 100000
# Whether to use asynchronous prefetching for the buffers
async_prefetch: bool = False async_prefetch: bool = False
# Number of steps before learning starts
online_step_before_learning: int = 100 online_step_before_learning: int = 100
# Frequency of policy updates
policy_update_freq: int = 1 policy_update_freq: int = 1
# SAC algorithm parameters # SAC algorithm parameters
# Discount factor for the SAC algorithm
discount: float = 0.99 discount: float = 0.99
# Initial temperature value
temperature_init: float = 1.0 temperature_init: float = 1.0
# Number of critics in the ensemble
num_critics: int = 2 num_critics: int = 2
# Number of subsampled critics for training
num_subsample_critics: int | None = None num_subsample_critics: int | None = None
# Learning rate for the critic network
critic_lr: float = 3e-4 critic_lr: float = 3e-4
# Learning rate for the actor network
actor_lr: float = 3e-4 actor_lr: float = 3e-4
# Learning rate for the temperature parameter
temperature_lr: float = 3e-4 temperature_lr: float = 3e-4
# Weight for the critic target update
critic_target_update_weight: float = 0.005 critic_target_update_weight: float = 0.005
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1 # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
utd_ratio: int = 1
# Hidden dimension size for the state encoder
state_encoder_hidden_dim: int = 256 state_encoder_hidden_dim: int = 256
# Dimension of the latent space
latent_dim: int = 256 latent_dim: int = 256
# Target entropy for the SAC algorithm
target_entropy: float | None = None target_entropy: float | None = None
# Whether to use backup entropy for the SAC algorithm
use_backup_entropy: bool = True use_backup_entropy: bool = True
# Gradient clipping norm for the SAC algorithm
grad_clip_norm: float = 40.0 grad_clip_norm: float = 40.0
# Network configuration # Network configuration
# Configuration for the critic network architecture
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for the actor network architecture
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
# Configuration for the policy parameters
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
# Configuration for the discrete critic network
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
# Configuration for actor-learner architecture
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
# Optimizations # Optimizations

View File

@@ -27,7 +27,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.policies.utils import get_device_from_parameters
@@ -394,17 +394,16 @@ class SACPolicy(
self.normalize_inputs = nn.Identity() self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity() self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity() self.unnormalize_outputs = nn.Identity()
if self.config.dataset_stats is not None:
if self.config.dataset_stats:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats) params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = Normalize( self.normalize_inputs = NormalizeBuffer(
self.config.input_features, self.config.normalization_mapping, params self.config.input_features, self.config.normalization_mapping, params
) )
stats = dataset_stats or params stats = dataset_stats or params
self.normalize_targets = Normalize( self.normalize_targets = NormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats self.config.output_features, self.config.normalization_mapping, stats
) )
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = UnnormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats self.config.output_features, self.config.normalization_mapping, stats
) )
@@ -506,7 +505,7 @@ class SACObservationEncoder(nn.Module):
if not self.has_images: if not self.has_images:
return return
if self.config.vision_encoder_name: if self.config.vision_encoder_name is not None:
self.image_encoder = PretrainedImageEncoder(self.config) self.image_encoder = PretrainedImageEncoder(self.config)
else: else:
self.image_encoder = DefaultImageEncoder(self.config) self.image_encoder = DefaultImageEncoder(self.config)

View File

@@ -19,9 +19,10 @@ import os.path as osp
import platform import platform
import subprocess import subprocess
import time import time
from copy import copy from copy import copy, deepcopy
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from statistics import mean
import numpy as np import numpy as np
import torch import torch
@@ -108,11 +109,14 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.") raise ValueError(f"Unknown device '{device}.")
def init_logging(log_file=None): def init_logging(log_file: Path | None = None, display_pid: bool = False):
def custom_format(record): def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}" fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
# NOTE: Display PID is useful for multi-process logging.
pid_str = f"[PID: {os.getpid()}]" if display_pid else ""
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
return message return message
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@@ -238,30 +242,99 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
class TimerManager: class TimerManager:
"""
Lightweight utility to measure elapsed time.
Examples
--------
>>> timer = TimerManager("Policy", log=False)
>>> for _ in range(3):
... with timer:
... time.sleep(0.01)
>>> print(timer.last, timer.fps_avg, timer.percentile(90))
"""
def __init__( def __init__(
self, self,
elapsed_time_list: list[float] | None = None, label: str = "Elapsed-time",
label="Elapsed time", log: bool = True,
log=True, logger: logging.Logger | None = None,
): ):
self.label = label self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log self.log = log
self.elapsed = 0.0 self.logger = logger
self._start: float | None = None
self._history: list[float] = []
def __enter__(self): def __enter__(self):
self.start = time.perf_counter() return self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
self._start = time.perf_counter()
return self return self
def __exit__(self, exc_type, exc_value, traceback): def stop(self) -> float:
self.elapsed: float = time.perf_counter() - self.start if self._start is None:
raise RuntimeError("Timer was never started.")
if self.elapsed_time_list is not None: elapsed = time.perf_counter() - self._start
self.elapsed_time_list.append(self.elapsed) self._history.append(elapsed)
self._start = None
if self.log: if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds") if self.logger is not None:
self.logger.info(f"{self.label}: {elapsed:.6f} s")
else:
logging.info(f"{self.label}: {elapsed:.6f} s")
return elapsed
def reset(self):
self._history.clear()
@property @property
def elapsed_seconds(self): def last(self) -> float:
return self.elapsed return self._history[-1] if self._history else 0.0
@property
def avg(self) -> float:
return mean(self._history) if self._history else 0.0
@property
def total(self) -> float:
return sum(self._history)
@property
def count(self) -> int:
return len(self._history)
@property
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
@property
def fps_avg(self) -> float:
return 0.0 if self.avg == 0 else 1.0 / self.avg
def percentile(self, p: float) -> float:
"""
Return the p-th percentile of recorded times.
"""
if not self._history:
return 0.0
return float(np.percentile(self._history, p))
def fps_percentile(self, p: float) -> float:
"""
FPS corresponding to the p-th percentile time.
"""
val = self.percentile(p)
return 0.0 if val == 0 else 1.0 / val

View File

@@ -123,9 +123,9 @@ class WandBLogger:
if step is None and custom_step_key is None: if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.") raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it # NOTE: This is not simple. Wandb step must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example, # increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps is possible for example, the interaction step with the environment, # multiple time steps is possible. For example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key # the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric. # to log the correct step for each metric.
if custom_step_key is not None: if custom_step_key is not None:

View File

@@ -13,13 +13,67 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with a specific robot type for a pick and place task:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--robot.type=so100 \
--task=pick_and_place
```
- Set a custom workspace bound for the robot's end-effector:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
```
- Run with specific camera crop parameters:
```bash
python lerobot/scripts/server/actor_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging import logging
import os import os
import time import time
from functools import lru_cache from functools import lru_cache
from queue import Empty from queue import Empty
from statistics import mean, quantiles
import grpc import grpc
import torch import torch
@@ -65,10 +119,12 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
@parser.wrap() @parser.wrap()
def actor_cli(cfg: TrainPipelineConfig): def actor_cli(cfg: TrainPipelineConfig):
cfg.validate() cfg.validate()
display_pid = False
if not use_threads(cfg): if not use_threads(cfg):
import torch.multiprocessing as mp import torch.multiprocessing as mp
mp.set_start_method("spawn") mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists # Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs") log_dir = os.path.join(cfg.output_dir, "logs")
@@ -76,7 +132,7 @@ def actor_cli(cfg: TrainPipelineConfig):
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log") log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}") logging.info(f"Actor logging initialized, writing to {log_file}")
shutdown_event = setup_process_handlers(use_threads(cfg)) shutdown_event = setup_process_handlers(use_threads(cfg))
@@ -193,7 +249,7 @@ def act_with_policy(
log_dir = os.path.join(cfg.output_dir, "logs") log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log") log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized") logging.info("Actor policy process logging initialized")
logging.info("make_env online") logging.info("make_env online")
@@ -223,12 +279,13 @@ def act_with_policy(
# NOTE: For the moment we will solely handle the case of a single environment # NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0 sum_reward_episode = 0
list_transition_to_send_to_learner = [] list_transition_to_send_to_learner = []
list_policy_time = []
episode_intervention = False episode_intervention = False
# Add counters for intervention rate calculation # Add counters for intervention rate calculation
episode_intervention_steps = 0 episode_intervention_steps = 0
episode_total_steps = 0 episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps): for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter() start_time = time.perf_counter()
if shutdown_event.is_set(): if shutdown_event.is_set():
@@ -237,13 +294,9 @@ def act_with_policy(
if interaction_step >= cfg.policy.online_step_before_learning: if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement # Time policy inference and check if it meets FPS requirement
with TimerManager( with policy_timer:
elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841
action = policy.select_action(batch=obs) action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9) policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -291,8 +344,8 @@ def act_with_policy(
) )
list_transition_to_send_to_learner = [] list_transition_to_send_to_learner = []
stats = get_frequency_stats(list_policy_time) stats = get_frequency_stats(policy_timer)
list_policy_time.clear() policy_timer.reset()
# Calculate intervention rate # Calculate intervention rate
intervention_rate = 0.0 intervention_rate = 0.0
@@ -429,7 +482,7 @@ def receive_policy(
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log") log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized") logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal
@@ -484,7 +537,7 @@ def send_transitions(
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log") log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized") logging.info("Actor transitions process logging initialized")
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal
@@ -533,7 +586,7 @@ def send_interactions(
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log") log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized") logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal
@@ -632,25 +685,24 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy. """Get the frequency statistics of the policy.
Args: Args:
list_policy_time (list[float]): The list of policy times. timer (TimerManager): The timer with collected metrics.
Returns: Returns:
dict[str, float]: The frequency statistics of the policy. dict[str, float]: The frequency statistics of the policy.
""" """
stats = {} stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time] if timer.count > 1:
if len(list_policy_fps) > 1: avg_fps = timer.fps_avg
policy_fps = mean(list_policy_fps) p90_fps = timer.fps_percentile(90)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1] logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = { stats = {
"Policy frequency [Hz]": policy_fps, "Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": quantiles_90, "Policy frequency 90th-p [Hz]": p90_fps,
} }
return stats return stats

View File

@@ -203,6 +203,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if key in new_dataset.meta.info["features"]: if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0 prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))): for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx] frame = original_dataset[frame_idx]

View File

@@ -23,10 +23,9 @@ import numpy as np
import torch import torch
from lerobot.common.robot_devices.utils import busy_wait from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.kinematics import RobotKinematics from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
class InputController: class InputController:
"""Base class for input controllers that generate motion deltas.""" """Base class for input controllers that generate motion deltas."""
@@ -726,6 +725,8 @@ if __name__ == "__main__":
from lerobot.common.robot_devices.robots.utils import make_robot_from_config from lerobot.common.robot_devices.robots.utils import make_robot_from_config
from lerobot.scripts.server.gym_manipulator import make_robot_env from lerobot.scripts.server.gym_manipulator import make_robot_env
init_logging()
parser = argparse.ArgumentParser(description="Test end-effector control") parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument( parser.add_argument(
"--mode", "--mode",

View File

@@ -1588,19 +1588,20 @@ class GamepadControlWrapper(gym.Wrapper):
input_threshold: Minimum movement delta to consider as active input. input_threshold: Minimum movement delta to consider as active input.
""" """
super().__init__(env) super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import (
GamepadController,
GamepadControllerHID,
)
# use HidApi for macos # use HidApi for macos
if sys.platform == "darwin": if sys.platform == "darwin":
# NOTE: On macOS, pygame doesn't reliably detect input from some controllers, so we fall back to hidapi
from lerobot.scripts.server.end_effector_control_utils import GamepadControllerHID
self.controller = GamepadControllerHID( self.controller = GamepadControllerHID(
x_step_size=x_step_size, x_step_size=x_step_size,
y_step_size=y_step_size, y_step_size=y_step_size,
z_step_size=z_step_size, z_step_size=z_step_size,
) )
else: else:
from lerobot.scripts.server.end_effector_control_utils import GamepadController
self.controller = GamepadController( self.controller = GamepadController(
x_step_size=x_step_size, x_step_size=x_step_size,
y_step_size=y_step_size, y_step_size=y_step_size,
@@ -1748,6 +1749,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
for k in obs: for k in obs:
obs[k] = obs[k].to(self.device) obs[k] = obs[k].to(self.device)
if "action_intervention" in info: if "action_intervention" in info:
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
info["action_intervention"] = info["action_intervention"].astype(np.float32)
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
@@ -1756,6 +1759,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
for k in obs: for k in obs:
obs[k] = obs[k].to(self.device) obs[k] = obs[k].to(self.device)
if "action_intervention" in info: if "action_intervention" in info:
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
info["action_intervention"] = info["action_intervention"].astype(np.float32)
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
return obs, info return obs, info

View File

@@ -14,6 +14,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
Learner server runner for distributed HILSerl robot policy training.
This script implements the learner component of the distributed HILSerl architecture.
It initializes the policy network, maintains replay buffers, and updates
the policy based on transitions received from the actor server.
Examples of usage:
- Start a learner server for training:
```bash
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
```
- Run with specific SAC hyperparameters:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--learner.sac.alpha=0.1 \
--learner.sac.gamma=0.99
```
- Run with a specific dataset and wandb logging:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--dataset.repo_id=username/pick_lift_cube \
--wandb.enable=true \
--wandb.project=hilserl_training
```
- Run with a pretrained policy for fine-tuning:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
```
- Run with a reward classifier model:
```bash
python lerobot/scripts/server/learner_server.py \
--config_path lerobot/configs/train_config_hilserl_so100.json \
--reward_classifier_pretrained_path=outputs/reward_model/best_model
```
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
to communicate with actors.
**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true
in your configuration.
**WORKFLOW**:
1. Create training configuration with proper policy, dataset, and environment settings
2. Start this learner server with the configuration
3. Start an actor server with the same configuration
4. Monitor training progress through wandb dashboard
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging import logging
import os import os
@@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import (
LOG_PREFIX = "[LEARNER]" LOG_PREFIX = "[LEARNER]"
logging.basicConfig(level=logging.INFO)
################################################# #################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # # MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
@@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
if job_name is None: if job_name is None:
raise ValueError("Job name must be specified either in config or as a parameter") raise ValueError("Job name must be specified either in config or as a parameter")
display_pid = False
if not use_threads(cfg):
display_pid = True
# Create logs directory to ensure it exists # Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs") log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_{job_name}.log") log_file = os.path.join(log_dir, f"learner_{job_name}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Learner logging initialized, writing to {log_file}") logging.info(f"Learner logging initialized, writing to {log_file}")
logging.info(pformat(cfg.to_dict())) logging.info(pformat(cfg.to_dict()))
@@ -275,7 +338,7 @@ def add_actor_information_and_train(
log_dir = os.path.join(cfg.output_dir, "logs") log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log") log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Initialized logging for actor information and training process") logging.info("Initialized logging for actor information and training process")
logging.info("Initializing policy") logging.info("Initializing policy")
@@ -604,7 +667,7 @@ def start_learner_server(
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log") log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
# Initialize logging with explicit log file # Initialize logging with explicit log file
init_logging(log_file=log_file) init_logging(log_file=log_file, display_pid=True)
logging.info("Learner server process logging initialized") logging.info("Learner server process logging initialized")
# Setup process handlers to handle shutdown signal # Setup process handlers to handle shutdown signal

View File

@@ -84,7 +84,7 @@ dora = [
] ]
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
hilserl = ["transformers>=4.48.0", "gym-hil>=0.1.2", "protobuf>=5.29.3", "grpcio>=1.70.0"] hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"] pi0 = ["transformers>=4.48.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]

View File

@@ -59,7 +59,6 @@ def test_sac_config_default_initialization():
assert config.num_critics == 2 assert config.num_critics == 2
# Architecture specifics # Architecture specifics
assert config.camera_number == 1
assert config.vision_encoder_name is None assert config.vision_encoder_name is None
assert config.freeze_vision_encoder is True assert config.freeze_vision_encoder is True
assert config.image_encoder_hidden_dim == 32 assert config.image_encoder_hidden_dim == 32

View File

@@ -9,6 +9,13 @@ from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.common.utils.random_utils import seeded_context from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
try:
import transformers # noqa: F401
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
def test_mlp_with_default_args(): def test_mlp_with_default_args():
mlp = MLP(input_dim=10, hidden_dims=[256, 256]) mlp = MLP(input_dim=10, hidden_dims=[256, 256])
@@ -274,6 +281,7 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
"batch_size,state_dim,action_dim,vision_encoder_name", "batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
) )
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder( def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
): ):