forked from tangger/lerobot
[HIL-SERL] Review feedback modifications (#1112)
This commit is contained in:
@@ -68,6 +68,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
|||||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||||
)
|
)
|
||||||
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
||||||
|
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19,
|
||||||
# env = ObservationProcessorWrapper(env=env)
|
# env = ObservationProcessorWrapper(env=env)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -81,35 +81,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
|
|
||||||
def __init__(self, env: gym.vector.VectorEnv):
|
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return preprocess_observation(observations)
|
|
||||||
|
|
||||||
def reset(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: int | list[int] | None = None,
|
|
||||||
options: dict[str, Any] | None = None,
|
|
||||||
):
|
|
||||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
|
||||||
observations, infos = self.env.reset(seed=seed, options=options)
|
|
||||||
return self._observations(observations), infos
|
|
||||||
|
|
||||||
def step(self, actions):
|
|
||||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
|
||||||
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
|
||||||
return (
|
|
||||||
self._observations(observations),
|
|
||||||
rewards,
|
|
||||||
terminations,
|
|
||||||
truncations,
|
|
||||||
infos,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
|
|||||||
@@ -46,6 +46,15 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||||
|
"""
|
||||||
|
Build the optimizer. It can be a single optimizer or a dictionary of optimizers.
|
||||||
|
NOTE: Multiple optimizers are useful when you have different models to optimize.
|
||||||
|
For example, you can have one optimizer for the policy and another one for the value function
|
||||||
|
in reinforcement learning settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The optimizer or a dictionary of optimizers.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -79,48 +79,28 @@ def create_stats_buffers(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats and key in stats:
|
if stats:
|
||||||
# NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
# Therefore, we don't access to full stats of the data, some elements either have min-max or mean-std only
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
|
||||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
|
||||||
# unnormalization). See the logic here
|
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
|
||||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
type_ = type(stats[key]["mean"])
|
|
||||||
raise ValueError(
|
|
||||||
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
|
||||||
if "min" not in stats[key] or "max" not in stats[key]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(stats[key]["min"], np.ndarray):
|
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||||
elif isinstance(stats[key]["min"], torch.Tensor):
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
|
# unnormalization). See the logic here
|
||||||
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||||
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["min"])
|
type_ = type(stats[key]["mean"])
|
||||||
raise ValueError(
|
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||||
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
@@ -169,13 +149,12 @@ class Normalize(nn.Module):
|
|||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
# @torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||||
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
@@ -244,7 +223,7 @@ class Unnormalize(nn.Module):
|
|||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
# @torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
@@ -273,3 +252,170 @@ class Unnormalize(nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(norm_mode)
|
raise ValueError(norm_mode)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: We should replace all normalization on the policies with register_buffer normalization
|
||||||
|
# and remove the `Normalize` and `Unnormalize` classes.
|
||||||
|
def _initialize_stats_buffers(
|
||||||
|
module: nn.Module,
|
||||||
|
features: dict[str, PolicyFeature],
|
||||||
|
norm_map: dict[str, NormalizationMode],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||||
|
|
||||||
|
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||||
|
but is factored out so it can be reused by both classes and stay in sync.
|
||||||
|
"""
|
||||||
|
for key, ft in features.items():
|
||||||
|
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
continue
|
||||||
|
|
||||||
|
shape: tuple[int, ...] = tuple(ft.shape)
|
||||||
|
if ft.type is FeatureType.VISUAL:
|
||||||
|
# reduce spatial dimensions, keep channel dimension only
|
||||||
|
c, *_ = shape
|
||||||
|
shape = (c, 1, 1)
|
||||||
|
|
||||||
|
prefix = key.replace(".", "_")
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||||
|
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||||
|
|
||||||
|
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||||
|
mean_data = stats[key]["mean"]
|
||||||
|
std_data = stats[key]["std"]
|
||||||
|
if isinstance(mean_data, np.ndarray):
|
||||||
|
mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
|
||||||
|
std = torch.from_numpy(std_data).to(dtype=torch.float32)
|
||||||
|
elif isinstance(mean_data, torch.Tensor):
|
||||||
|
mean = mean_data.clone().to(dtype=torch.float32)
|
||||||
|
std = std_data.clone().to(dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||||
|
|
||||||
|
module.register_buffer(f"{prefix}_mean", mean)
|
||||||
|
module.register_buffer(f"{prefix}_std", std)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||||
|
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||||
|
|
||||||
|
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||||
|
min_data = stats[key]["min"]
|
||||||
|
max_data = stats[key]["max"]
|
||||||
|
if isinstance(min_data, np.ndarray):
|
||||||
|
min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
|
||||||
|
max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
|
||||||
|
elif isinstance(min_data, torch.Tensor):
|
||||||
|
min_val = min_data.clone().to(dtype=torch.float32)
|
||||||
|
max_val = max_data.clone().to(dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||||
|
|
||||||
|
module.register_buffer(f"{prefix}_min", min_val)
|
||||||
|
module.register_buffer(f"{prefix}_max", max_val)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(norm_mode)
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeBuffer(nn.Module):
|
||||||
|
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
features: dict[str, PolicyFeature],
|
||||||
|
norm_map: dict[str, NormalizationMode],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.features = features
|
||||||
|
self.norm_map = norm_map
|
||||||
|
|
||||||
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
batch = dict(batch)
|
||||||
|
for key, ft in self.features.items():
|
||||||
|
if key not in batch:
|
||||||
|
continue
|
||||||
|
|
||||||
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
continue
|
||||||
|
|
||||||
|
prefix = key.replace(".", "_")
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
mean = getattr(self, f"{prefix}_mean")
|
||||||
|
std = getattr(self, f"{prefix}_std")
|
||||||
|
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||||
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
min_val = getattr(self, f"{prefix}_min")
|
||||||
|
max_val = getattr(self, f"{prefix}_max")
|
||||||
|
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||||
|
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||||
|
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||||
|
batch[key] = batch[key] * 2 - 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(norm_mode)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
class UnnormalizeBuffer(nn.Module):
|
||||||
|
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
features: dict[str, PolicyFeature],
|
||||||
|
norm_map: dict[str, NormalizationMode],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.features = features
|
||||||
|
self.norm_map = norm_map
|
||||||
|
|
||||||
|
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
batch = dict(batch)
|
||||||
|
for key, ft in self.features.items():
|
||||||
|
if key not in batch:
|
||||||
|
continue
|
||||||
|
|
||||||
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
|
if norm_mode is NormalizationMode.IDENTITY:
|
||||||
|
continue
|
||||||
|
|
||||||
|
prefix = key.replace(".", "_")
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
mean = getattr(self, f"{prefix}_mean")
|
||||||
|
std = getattr(self, f"{prefix}_std")
|
||||||
|
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||||
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
|
batch[key] = batch[key] * std + mean
|
||||||
|
continue
|
||||||
|
|
||||||
|
if norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
min_val = getattr(self, f"{prefix}_min")
|
||||||
|
max_val = getattr(self, f"{prefix}_max")
|
||||||
|
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||||
|
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||||
|
batch[key] = (batch[key] + 1) / 2
|
||||||
|
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(norm_mode)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|||||||
@@ -9,9 +9,6 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierOutput:
|
class ClassifierOutput:
|
||||||
"""Wrapper for classifier outputs with additional metadata."""
|
"""Wrapper for classifier outputs with additional metadata."""
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ from lerobot.configs.types import NormalizationMode
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConcurrencyConfig:
|
class ConcurrencyConfig:
|
||||||
|
"""Configuration for the concurrency of the actor and learner.
|
||||||
|
Possible values are:
|
||||||
|
- "threads": Use threads for the actor and learner.
|
||||||
|
- "processes": Use processes for the actor and learner.
|
||||||
|
"""
|
||||||
|
|
||||||
actor: str = "threads"
|
actor: str = "threads"
|
||||||
learner: str = "threads"
|
learner: str = "threads"
|
||||||
|
|
||||||
@@ -68,51 +74,9 @@ class SACConfig(PreTrainedConfig):
|
|||||||
This configuration class contains all the parameters needed to define a SAC agent,
|
This configuration class contains all the parameters needed to define a SAC agent,
|
||||||
including network architectures, optimization settings, and algorithm-specific
|
including network architectures, optimization settings, and algorithm-specific
|
||||||
hyperparameters.
|
hyperparameters.
|
||||||
|
|
||||||
Args:
|
|
||||||
actor_network_kwargs: Configuration for the actor network architecture.
|
|
||||||
critic_network_kwargs: Configuration for the critic network architecture.
|
|
||||||
discrete_critic_network_kwargs: Configuration for the discrete critic network.
|
|
||||||
policy_kwargs: Configuration for the policy parameters.
|
|
||||||
n_obs_steps: Number of observation steps to consider.
|
|
||||||
normalization_mapping: Mapping of feature types to normalization modes.
|
|
||||||
dataset_stats: Statistics for normalizing different types of inputs.
|
|
||||||
input_features: Dictionary of input features with their types and shapes.
|
|
||||||
output_features: Dictionary of output features with their types and shapes.
|
|
||||||
camera_number: Number of cameras used for visual observations.
|
|
||||||
device: Device to run the model on (e.g., "cuda", "cpu").
|
|
||||||
storage_device: Device to store the model on.
|
|
||||||
vision_encoder_name: Name of the vision encoder model.
|
|
||||||
freeze_vision_encoder: Whether to freeze the vision encoder during training.
|
|
||||||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
|
||||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
|
||||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
|
||||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
|
||||||
concurrency: Configuration for concurrency settings.
|
|
||||||
actor_learner_config: Configuration for actor-learner architecture.
|
|
||||||
online_steps: Number of steps for online training.
|
|
||||||
online_env_seed: Seed for the online environment.
|
|
||||||
online_buffer_capacity: Capacity of the online replay buffer.
|
|
||||||
offline_buffer_capacity: Capacity of the offline replay buffer.
|
|
||||||
async_prefetch: Whether to use asynchronous prefetching for the buffers.
|
|
||||||
online_step_before_learning: Number of steps before learning starts.
|
|
||||||
policy_update_freq: Frequency of policy updates.
|
|
||||||
discount: Discount factor for the SAC algorithm.
|
|
||||||
temperature_init: Initial temperature value.
|
|
||||||
num_critics: Number of critics in the ensemble.
|
|
||||||
num_subsample_critics: Number of subsampled critics for training.
|
|
||||||
critic_lr: Learning rate for the critic network.
|
|
||||||
actor_lr: Learning rate for the actor network.
|
|
||||||
temperature_lr: Learning rate for the temperature parameter.
|
|
||||||
critic_target_update_weight: Weight for the critic target update.
|
|
||||||
utd_ratio: Update-to-data ratio for the UTD algorithm.
|
|
||||||
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
|
|
||||||
latent_dim: Dimension of the latent space.
|
|
||||||
target_entropy: Target entropy for the SAC algorithm.
|
|
||||||
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
|
|
||||||
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Mapping of feature types to normalization modes
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.MEAN_STD,
|
"VISUAL": NormalizationMode.MEAN_STD,
|
||||||
@@ -122,6 +86,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Statistics for normalizing different types of inputs
|
||||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": {
|
"observation.image": {
|
||||||
@@ -140,47 +105,81 @@ class SACConfig(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Architecture specifics
|
# Architecture specifics
|
||||||
|
# Device to run the model on (e.g., "cuda", "cpu")
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
|
# Device to store the model on
|
||||||
storage_device: str = "cpu"
|
storage_device: str = "cpu"
|
||||||
# Set to "helper2424/resnet10" for hil serl
|
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||||
vision_encoder_name: str | None = None
|
vision_encoder_name: str | None = None
|
||||||
|
# Whether to freeze the vision encoder during training
|
||||||
freeze_vision_encoder: bool = True
|
freeze_vision_encoder: bool = True
|
||||||
|
# Hidden dimension size for the image encoder
|
||||||
image_encoder_hidden_dim: int = 32
|
image_encoder_hidden_dim: int = 32
|
||||||
|
# Whether to use a shared encoder for actor and critic
|
||||||
shared_encoder: bool = True
|
shared_encoder: bool = True
|
||||||
|
# Number of discrete actions, eg for gripper actions
|
||||||
num_discrete_actions: int | None = None
|
num_discrete_actions: int | None = None
|
||||||
|
# Dimension of the image embedding pooling
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
|
|
||||||
# Training parameter
|
# Training parameter
|
||||||
|
# Number of steps for online training
|
||||||
online_steps: int = 1000000
|
online_steps: int = 1000000
|
||||||
|
# Seed for the online environment
|
||||||
online_env_seed: int = 10000
|
online_env_seed: int = 10000
|
||||||
|
# Capacity of the online replay buffer
|
||||||
online_buffer_capacity: int = 100000
|
online_buffer_capacity: int = 100000
|
||||||
|
# Capacity of the offline replay buffer
|
||||||
offline_buffer_capacity: int = 100000
|
offline_buffer_capacity: int = 100000
|
||||||
|
# Whether to use asynchronous prefetching for the buffers
|
||||||
async_prefetch: bool = False
|
async_prefetch: bool = False
|
||||||
|
# Number of steps before learning starts
|
||||||
online_step_before_learning: int = 100
|
online_step_before_learning: int = 100
|
||||||
|
# Frequency of policy updates
|
||||||
policy_update_freq: int = 1
|
policy_update_freq: int = 1
|
||||||
|
|
||||||
# SAC algorithm parameters
|
# SAC algorithm parameters
|
||||||
|
# Discount factor for the SAC algorithm
|
||||||
discount: float = 0.99
|
discount: float = 0.99
|
||||||
|
# Initial temperature value
|
||||||
temperature_init: float = 1.0
|
temperature_init: float = 1.0
|
||||||
|
# Number of critics in the ensemble
|
||||||
num_critics: int = 2
|
num_critics: int = 2
|
||||||
|
# Number of subsampled critics for training
|
||||||
num_subsample_critics: int | None = None
|
num_subsample_critics: int | None = None
|
||||||
|
# Learning rate for the critic network
|
||||||
critic_lr: float = 3e-4
|
critic_lr: float = 3e-4
|
||||||
|
# Learning rate for the actor network
|
||||||
actor_lr: float = 3e-4
|
actor_lr: float = 3e-4
|
||||||
|
# Learning rate for the temperature parameter
|
||||||
temperature_lr: float = 3e-4
|
temperature_lr: float = 3e-4
|
||||||
|
# Weight for the critic target update
|
||||||
critic_target_update_weight: float = 0.005
|
critic_target_update_weight: float = 0.005
|
||||||
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
|
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||||
|
utd_ratio: int = 1
|
||||||
|
# Hidden dimension size for the state encoder
|
||||||
state_encoder_hidden_dim: int = 256
|
state_encoder_hidden_dim: int = 256
|
||||||
|
# Dimension of the latent space
|
||||||
latent_dim: int = 256
|
latent_dim: int = 256
|
||||||
|
# Target entropy for the SAC algorithm
|
||||||
target_entropy: float | None = None
|
target_entropy: float | None = None
|
||||||
|
# Whether to use backup entropy for the SAC algorithm
|
||||||
use_backup_entropy: bool = True
|
use_backup_entropy: bool = True
|
||||||
|
# Gradient clipping norm for the SAC algorithm
|
||||||
grad_clip_norm: float = 40.0
|
grad_clip_norm: float = 40.0
|
||||||
|
|
||||||
# Network configuration
|
# Network configuration
|
||||||
|
# Configuration for the critic network architecture
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
|
# Configuration for the actor network architecture
|
||||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||||
|
# Configuration for the policy parameters
|
||||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||||
|
# Configuration for the discrete critic network
|
||||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
|
# Configuration for actor-learner architecture
|
||||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||||
|
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||||
|
|
||||||
# Optimizations
|
# Optimizations
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import torch.nn.functional as F # noqa: N812
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||||
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
@@ -394,17 +394,16 @@ class SACPolicy(
|
|||||||
self.normalize_inputs = nn.Identity()
|
self.normalize_inputs = nn.Identity()
|
||||||
self.normalize_targets = nn.Identity()
|
self.normalize_targets = nn.Identity()
|
||||||
self.unnormalize_outputs = nn.Identity()
|
self.unnormalize_outputs = nn.Identity()
|
||||||
|
if self.config.dataset_stats is not None:
|
||||||
if self.config.dataset_stats:
|
|
||||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||||
self.normalize_inputs = Normalize(
|
self.normalize_inputs = NormalizeBuffer(
|
||||||
self.config.input_features, self.config.normalization_mapping, params
|
self.config.input_features, self.config.normalization_mapping, params
|
||||||
)
|
)
|
||||||
stats = dataset_stats or params
|
stats = dataset_stats or params
|
||||||
self.normalize_targets = Normalize(
|
self.normalize_targets = NormalizeBuffer(
|
||||||
self.config.output_features, self.config.normalization_mapping, stats
|
self.config.output_features, self.config.normalization_mapping, stats
|
||||||
)
|
)
|
||||||
self.unnormalize_outputs = Unnormalize(
|
self.unnormalize_outputs = UnnormalizeBuffer(
|
||||||
self.config.output_features, self.config.normalization_mapping, stats
|
self.config.output_features, self.config.normalization_mapping, stats
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -506,7 +505,7 @@ class SACObservationEncoder(nn.Module):
|
|||||||
if not self.has_images:
|
if not self.has_images:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.config.vision_encoder_name:
|
if self.config.vision_encoder_name is not None:
|
||||||
self.image_encoder = PretrainedImageEncoder(self.config)
|
self.image_encoder = PretrainedImageEncoder(self.config)
|
||||||
else:
|
else:
|
||||||
self.image_encoder = DefaultImageEncoder(self.config)
|
self.image_encoder = DefaultImageEncoder(self.config)
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ import os.path as osp
|
|||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from copy import copy
|
from copy import copy, deepcopy
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -108,11 +109,14 @@ def is_amp_available(device: str):
|
|||||||
raise ValueError(f"Unknown device '{device}.")
|
raise ValueError(f"Unknown device '{device}.")
|
||||||
|
|
||||||
|
|
||||||
def init_logging(log_file=None):
|
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||||
def custom_format(record):
|
def custom_format(record):
|
||||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
fnameline = f"{record.pathname}:{record.lineno}"
|
fnameline = f"{record.pathname}:{record.lineno}"
|
||||||
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
|
|
||||||
|
# NOTE: Display PID is useful for multi-process logging.
|
||||||
|
pid_str = f"[PID: {os.getpid()}]" if display_pid else ""
|
||||||
|
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||||
return message
|
return message
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -238,30 +242,99 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
class TimerManager:
|
class TimerManager:
|
||||||
|
"""
|
||||||
|
Lightweight utility to measure elapsed time.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> timer = TimerManager("Policy", log=False)
|
||||||
|
>>> for _ in range(3):
|
||||||
|
... with timer:
|
||||||
|
... time.sleep(0.01)
|
||||||
|
>>> print(timer.last, timer.fps_avg, timer.percentile(90))
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
elapsed_time_list: list[float] | None = None,
|
label: str = "Elapsed-time",
|
||||||
label="Elapsed time",
|
log: bool = True,
|
||||||
log=True,
|
logger: logging.Logger | None = None,
|
||||||
):
|
):
|
||||||
self.label = label
|
self.label = label
|
||||||
self.elapsed_time_list = elapsed_time_list
|
|
||||||
self.log = log
|
self.log = log
|
||||||
self.elapsed = 0.0
|
self.logger = logger
|
||||||
|
self._start: float | None = None
|
||||||
|
self._history: list[float] = []
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.start = time.perf_counter()
|
return self.start()
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._start = time.perf_counter()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def stop(self) -> float:
|
||||||
self.elapsed: float = time.perf_counter() - self.start
|
if self._start is None:
|
||||||
|
raise RuntimeError("Timer was never started.")
|
||||||
if self.elapsed_time_list is not None:
|
elapsed = time.perf_counter() - self._start
|
||||||
self.elapsed_time_list.append(self.elapsed)
|
self._history.append(elapsed)
|
||||||
|
self._start = None
|
||||||
if self.log:
|
if self.log:
|
||||||
print(f"{self.label}: {self.elapsed:.6f} seconds")
|
if self.logger is not None:
|
||||||
|
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||||
|
else:
|
||||||
|
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||||
|
return elapsed
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._history.clear()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def elapsed_seconds(self):
|
def last(self) -> float:
|
||||||
return self.elapsed
|
return self._history[-1] if self._history else 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self) -> float:
|
||||||
|
return mean(self._history) if self._history else 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> float:
|
||||||
|
return sum(self._history)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
return len(self._history)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history(self) -> list[float]:
|
||||||
|
return deepcopy(self._history)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps_history(self) -> list[float]:
|
||||||
|
return [1.0 / t for t in self._history]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps_last(self) -> float:
|
||||||
|
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps_avg(self) -> float:
|
||||||
|
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||||
|
|
||||||
|
def percentile(self, p: float) -> float:
|
||||||
|
"""
|
||||||
|
Return the p-th percentile of recorded times.
|
||||||
|
"""
|
||||||
|
if not self._history:
|
||||||
|
return 0.0
|
||||||
|
return float(np.percentile(self._history, p))
|
||||||
|
|
||||||
|
def fps_percentile(self, p: float) -> float:
|
||||||
|
"""
|
||||||
|
FPS corresponding to the p-th percentile time.
|
||||||
|
"""
|
||||||
|
val = self.percentile(p)
|
||||||
|
return 0.0 if val == 0 else 1.0 / val
|
||||||
|
|||||||
@@ -123,9 +123,9 @@ class WandBLogger:
|
|||||||
if step is None and custom_step_key is None:
|
if step is None and custom_step_key is None:
|
||||||
raise ValueError("Either step or custom_step_key must be provided.")
|
raise ValueError("Either step or custom_step_key must be provided.")
|
||||||
|
|
||||||
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
|
# NOTE: This is not simple. Wandb step must always monotonically increase and it
|
||||||
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
# increases with each wandb.log call, but in the case of asynchronous RL for example,
|
||||||
# multiple time steps is possible for example, the interaction step with the environment,
|
# multiple time steps is possible. For example, the interaction step with the environment,
|
||||||
# the training step, the evaluation step, etc. So we need to define a custom step key
|
# the training step, the evaluation step, etc. So we need to define a custom step key
|
||||||
# to log the correct step for each metric.
|
# to log the correct step for each metric.
|
||||||
if custom_step_key is not None:
|
if custom_step_key is not None:
|
||||||
|
|||||||
@@ -13,13 +13,67 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Actor server runner for distributed HILSerl robot policy training.
|
||||||
|
|
||||||
|
This script implements the actor component of the distributed HILSerl architecture.
|
||||||
|
It executes the policy in the robot environment, collects experience,
|
||||||
|
and sends transitions to the learner server for policy updates.
|
||||||
|
|
||||||
|
Examples of usage:
|
||||||
|
|
||||||
|
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/actor_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with a specific robot type for a pick and place task:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/actor_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--robot.type=so100 \
|
||||||
|
--task=pick_and_place
|
||||||
|
```
|
||||||
|
|
||||||
|
- Set a custom workspace bound for the robot's end-effector:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/actor_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--env.ee_action_space_params.bounds.max="[0.24, 0.20, 0.10]" \
|
||||||
|
--env.ee_action_space_params.bounds.min="[0.16, -0.08, 0.03]"
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with specific camera crop parameters:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/actor_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--env.crop_params_dict="{'observation.images.side': [180, 207, 180, 200], 'observation.images.front': [180, 250, 120, 150]}"
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||||
|
server is started before launching the actor.
|
||||||
|
|
||||||
|
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||||
|
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||||
|
reduce interventions as the policy improves.
|
||||||
|
|
||||||
|
**WORKFLOW**:
|
||||||
|
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||||
|
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||||
|
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||||
|
4. Start the learner server with the training configuration
|
||||||
|
5. Start this actor server with the same configuration
|
||||||
|
6. Use human interventions to guide policy learning
|
||||||
|
|
||||||
|
For more details on the complete HILSerl training workflow, see:
|
||||||
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
from statistics import mean, quantiles
|
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import torch
|
import torch
|
||||||
@@ -65,10 +119,12 @@ ACTOR_SHUTDOWN_TIMEOUT = 30
|
|||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def actor_cli(cfg: TrainPipelineConfig):
|
def actor_cli(cfg: TrainPipelineConfig):
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
display_pid = False
|
||||||
if not use_threads(cfg):
|
if not use_threads(cfg):
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn")
|
||||||
|
display_pid = True
|
||||||
|
|
||||||
# Create logs directory to ensure it exists
|
# Create logs directory to ensure it exists
|
||||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||||
@@ -76,7 +132,7 @@ def actor_cli(cfg: TrainPipelineConfig):
|
|||||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=display_pid)
|
||||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||||
|
|
||||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||||
@@ -193,7 +249,7 @@ def act_with_policy(
|
|||||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Actor policy process logging initialized")
|
logging.info("Actor policy process logging initialized")
|
||||||
|
|
||||||
logging.info("make_env online")
|
logging.info("make_env online")
|
||||||
@@ -223,12 +279,13 @@ def act_with_policy(
|
|||||||
# NOTE: For the moment we will solely handle the case of a single environment
|
# NOTE: For the moment we will solely handle the case of a single environment
|
||||||
sum_reward_episode = 0
|
sum_reward_episode = 0
|
||||||
list_transition_to_send_to_learner = []
|
list_transition_to_send_to_learner = []
|
||||||
list_policy_time = []
|
|
||||||
episode_intervention = False
|
episode_intervention = False
|
||||||
# Add counters for intervention rate calculation
|
# Add counters for intervention rate calculation
|
||||||
episode_intervention_steps = 0
|
episode_intervention_steps = 0
|
||||||
episode_total_steps = 0
|
episode_total_steps = 0
|
||||||
|
|
||||||
|
policy_timer = TimerManager("Policy inference", log=False)
|
||||||
|
|
||||||
for interaction_step in range(cfg.policy.online_steps):
|
for interaction_step in range(cfg.policy.online_steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
@@ -237,13 +294,9 @@ def act_with_policy(
|
|||||||
|
|
||||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||||
# Time policy inference and check if it meets FPS requirement
|
# Time policy inference and check if it meets FPS requirement
|
||||||
with TimerManager(
|
with policy_timer:
|
||||||
elapsed_time_list=list_policy_time,
|
|
||||||
label="Policy inference time",
|
|
||||||
log=False,
|
|
||||||
) as timer: # noqa: F841
|
|
||||||
action = policy.select_action(batch=obs)
|
action = policy.select_action(batch=obs)
|
||||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
|
|
||||||
@@ -291,8 +344,8 @@ def act_with_policy(
|
|||||||
)
|
)
|
||||||
list_transition_to_send_to_learner = []
|
list_transition_to_send_to_learner = []
|
||||||
|
|
||||||
stats = get_frequency_stats(list_policy_time)
|
stats = get_frequency_stats(policy_timer)
|
||||||
list_policy_time.clear()
|
policy_timer.reset()
|
||||||
|
|
||||||
# Calculate intervention rate
|
# Calculate intervention rate
|
||||||
intervention_rate = 0.0
|
intervention_rate = 0.0
|
||||||
@@ -429,7 +482,7 @@ def receive_policy(
|
|||||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Actor receive policy process logging initialized")
|
logging.info("Actor receive policy process logging initialized")
|
||||||
|
|
||||||
# Setup process handlers to handle shutdown signal
|
# Setup process handlers to handle shutdown signal
|
||||||
@@ -484,7 +537,7 @@ def send_transitions(
|
|||||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Actor transitions process logging initialized")
|
logging.info("Actor transitions process logging initialized")
|
||||||
|
|
||||||
# Setup process handlers to handle shutdown signal
|
# Setup process handlers to handle shutdown signal
|
||||||
@@ -533,7 +586,7 @@ def send_interactions(
|
|||||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Actor interactions process logging initialized")
|
logging.info("Actor interactions process logging initialized")
|
||||||
|
|
||||||
# Setup process handlers to handle shutdown signal
|
# Setup process handlers to handle shutdown signal
|
||||||
@@ -632,25 +685,24 @@ def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
|||||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||||
|
|
||||||
|
|
||||||
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||||
"""Get the frequency statistics of the policy.
|
"""Get the frequency statistics of the policy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
list_policy_time (list[float]): The list of policy times.
|
timer (TimerManager): The timer with collected metrics.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, float]: The frequency statistics of the policy.
|
dict[str, float]: The frequency statistics of the policy.
|
||||||
"""
|
"""
|
||||||
stats = {}
|
stats = {}
|
||||||
list_policy_fps = [1.0 / t for t in list_policy_time]
|
if timer.count > 1:
|
||||||
if len(list_policy_fps) > 1:
|
avg_fps = timer.fps_avg
|
||||||
policy_fps = mean(list_policy_fps)
|
p90_fps = timer.fps_percentile(90)
|
||||||
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
|
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||||
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
|
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
|
|
||||||
stats = {
|
stats = {
|
||||||
"Policy frequency [Hz]": policy_fps,
|
"Policy frequency [Hz]": avg_fps,
|
||||||
"Policy frequency 90th-p [Hz]": quantiles_90,
|
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||||
}
|
}
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
|||||||
if key in new_dataset.meta.info["features"]:
|
if key in new_dataset.meta.info["features"]:
|
||||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||||
|
|
||||||
|
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||||
prev_episode_index = 0
|
prev_episode_index = 0
|
||||||
for frame_idx in tqdm(range(len(original_dataset))):
|
for frame_idx in tqdm(range(len(original_dataset))):
|
||||||
frame = original_dataset[frame_idx]
|
frame = original_dataset[frame_idx]
|
||||||
|
|||||||
@@ -23,10 +23,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
class InputController:
|
class InputController:
|
||||||
"""Base class for input controllers that generate motion deltas."""
|
"""Base class for input controllers that generate motion deltas."""
|
||||||
@@ -726,6 +725,8 @@ if __name__ == "__main__":
|
|||||||
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
from lerobot.common.robot_devices.robots.utils import make_robot_from_config
|
||||||
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
from lerobot.scripts.server.gym_manipulator import make_robot_env
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Test end-effector control")
|
parser = argparse.ArgumentParser(description="Test end-effector control")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mode",
|
"--mode",
|
||||||
|
|||||||
@@ -1588,19 +1588,20 @@ class GamepadControlWrapper(gym.Wrapper):
|
|||||||
input_threshold: Minimum movement delta to consider as active input.
|
input_threshold: Minimum movement delta to consider as active input.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
from lerobot.scripts.server.end_effector_control_utils import (
|
|
||||||
GamepadController,
|
|
||||||
GamepadControllerHID,
|
|
||||||
)
|
|
||||||
|
|
||||||
# use HidApi for macos
|
# use HidApi for macos
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
|
# NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi
|
||||||
|
from lerobot.scripts.server.end_effector_control_utils import GamepadControllerHID
|
||||||
|
|
||||||
self.controller = GamepadControllerHID(
|
self.controller = GamepadControllerHID(
|
||||||
x_step_size=x_step_size,
|
x_step_size=x_step_size,
|
||||||
y_step_size=y_step_size,
|
y_step_size=y_step_size,
|
||||||
z_step_size=z_step_size,
|
z_step_size=z_step_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
from lerobot.scripts.server.end_effector_control_utils import GamepadController
|
||||||
|
|
||||||
self.controller = GamepadController(
|
self.controller = GamepadController(
|
||||||
x_step_size=x_step_size,
|
x_step_size=x_step_size,
|
||||||
y_step_size=y_step_size,
|
y_step_size=y_step_size,
|
||||||
@@ -1748,6 +1749,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
|
|||||||
for k in obs:
|
for k in obs:
|
||||||
obs[k] = obs[k].to(self.device)
|
obs[k] = obs[k].to(self.device)
|
||||||
if "action_intervention" in info:
|
if "action_intervention" in info:
|
||||||
|
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
|
||||||
|
info["action_intervention"] = info["action_intervention"].astype(np.float32)
|
||||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
@@ -1756,6 +1759,8 @@ class GymHilDeviceWrapper(gym.Wrapper):
|
|||||||
for k in obs:
|
for k in obs:
|
||||||
obs[k] = obs[k].to(self.device)
|
obs[k] = obs[k].to(self.device)
|
||||||
if "action_intervention" in info:
|
if "action_intervention" in info:
|
||||||
|
# NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device
|
||||||
|
info["action_intervention"] = info["action_intervention"].astype(np.float32)
|
||||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||||
return obs, info
|
return obs, info
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,66 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Learner server runner for distributed HILSerl robot policy training.
|
||||||
|
|
||||||
|
This script implements the learner component of the distributed HILSerl architecture.
|
||||||
|
It initializes the policy network, maintains replay buffers, and updates
|
||||||
|
the policy based on transitions received from the actor server.
|
||||||
|
|
||||||
|
Examples of usage:
|
||||||
|
|
||||||
|
- Start a learner server for training:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/learner_server.py --config_path lerobot/configs/train_config_hilserl_so100.json
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with specific SAC hyperparameters:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/learner_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--learner.sac.alpha=0.1 \
|
||||||
|
--learner.sac.gamma=0.99
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with a specific dataset and wandb logging:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/learner_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--dataset.repo_id=username/pick_lift_cube \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--wandb.project=hilserl_training
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with a pretrained policy for fine-tuning:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/learner_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--pretrained_policy_name_or_path=outputs/previous_training/checkpoints/080000/pretrained_model
|
||||||
|
```
|
||||||
|
|
||||||
|
- Run with a reward classifier model:
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/server/learner_server.py \
|
||||||
|
--config_path lerobot/configs/train_config_hilserl_so100.json \
|
||||||
|
--reward_classifier_pretrained_path=outputs/reward_model/best_model
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server
|
||||||
|
to communicate with actors.
|
||||||
|
|
||||||
|
**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true
|
||||||
|
in your configuration.
|
||||||
|
|
||||||
|
**WORKFLOW**:
|
||||||
|
1. Create training configuration with proper policy, dataset, and environment settings
|
||||||
|
2. Start this learner server with the configuration
|
||||||
|
3. Start an actor server with the same configuration
|
||||||
|
4. Monitor training progress through wandb dashboard
|
||||||
|
|
||||||
|
For more details on the complete HILSerl training workflow, see:
|
||||||
|
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -73,7 +133,6 @@ from lerobot.scripts.server.utils import (
|
|||||||
|
|
||||||
LOG_PREFIX = "[LEARNER]"
|
LOG_PREFIX = "[LEARNER]"
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
||||||
@@ -113,13 +172,17 @@ def train(cfg: TrainPipelineConfig, job_name: str | None = None):
|
|||||||
if job_name is None:
|
if job_name is None:
|
||||||
raise ValueError("Job name must be specified either in config or as a parameter")
|
raise ValueError("Job name must be specified either in config or as a parameter")
|
||||||
|
|
||||||
|
display_pid = False
|
||||||
|
if not use_threads(cfg):
|
||||||
|
display_pid = True
|
||||||
|
|
||||||
# Create logs directory to ensure it exists
|
# Create logs directory to ensure it exists
|
||||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
|
log_file = os.path.join(log_dir, f"learner_{job_name}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=display_pid)
|
||||||
logging.info(f"Learner logging initialized, writing to {log_file}")
|
logging.info(f"Learner logging initialized, writing to {log_file}")
|
||||||
logging.info(pformat(cfg.to_dict()))
|
logging.info(pformat(cfg.to_dict()))
|
||||||
|
|
||||||
@@ -275,7 +338,7 @@ def add_actor_information_and_train(
|
|||||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log")
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Initialized logging for actor information and training process")
|
logging.info("Initialized logging for actor information and training process")
|
||||||
|
|
||||||
logging.info("Initializing policy")
|
logging.info("Initializing policy")
|
||||||
@@ -604,7 +667,7 @@ def start_learner_server(
|
|||||||
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
|
log_file = os.path.join(log_dir, f"learner_server_process_{os.getpid()}.log")
|
||||||
|
|
||||||
# Initialize logging with explicit log file
|
# Initialize logging with explicit log file
|
||||||
init_logging(log_file=log_file)
|
init_logging(log_file=log_file, display_pid=True)
|
||||||
logging.info("Learner server process logging initialized")
|
logging.info("Learner server process logging initialized")
|
||||||
|
|
||||||
# Setup process handlers to handle shutdown signal
|
# Setup process handlers to handle shutdown signal
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ dora = [
|
|||||||
]
|
]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
hilserl = ["transformers>=4.48.0", "gym-hil>=0.1.2", "protobuf>=5.29.3", "grpcio>=1.70.0"]
|
hilserl = ["transformers>=4.48", "gym-hil>=0.1.3", "protobuf>=5.29.3", "grpcio>=1.70.0"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||||
|
|||||||
@@ -59,7 +59,6 @@ def test_sac_config_default_initialization():
|
|||||||
assert config.num_critics == 2
|
assert config.num_critics == 2
|
||||||
|
|
||||||
# Architecture specifics
|
# Architecture specifics
|
||||||
assert config.camera_number == 1
|
|
||||||
assert config.vision_encoder_name is None
|
assert config.vision_encoder_name is None
|
||||||
assert config.freeze_vision_encoder is True
|
assert config.freeze_vision_encoder is True
|
||||||
assert config.image_encoder_hidden_dim == 32
|
assert config.image_encoder_hidden_dim == 32
|
||||||
|
|||||||
@@ -9,6 +9,13 @@ from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
|||||||
from lerobot.common.utils.random_utils import seeded_context
|
from lerobot.common.utils.random_utils import seeded_context
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
try:
|
||||||
|
import transformers # noqa: F401
|
||||||
|
|
||||||
|
TRANSFORMERS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TRANSFORMERS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
def test_mlp_with_default_args():
|
def test_mlp_with_default_args():
|
||||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||||
@@ -274,6 +281,7 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
|
|||||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||||
def test_sac_policy_with_pretrained_encoder(
|
def test_sac_policy_with_pretrained_encoder(
|
||||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user