From 6fa7df35df4904b7eb38f419759d053254b6373c Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 5 May 2025 14:27:42 +0700 Subject: [PATCH] [PORT HIL-SERL] Add unit tests for SAC modeling (#999) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/policies/sac/configuration_sac.py | 11 +- tests/policies/test_sac_config.py | 218 ++++++++++++++++++ tests/policies/test_sac_policy.py | 43 ++++ 3 files changed, 267 insertions(+), 5 deletions(-) create mode 100644 tests/policies/test_sac_config.py create mode 100644 tests/policies/test_sac_policy.py diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index c9bd90fc..2fabf5a8 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -70,9 +70,10 @@ class SACConfig(PreTrainedConfig): hyperparameters. Args: - actor_network: Configuration for the actor network architecture. - critic_network: Configuration for the critic network architecture. - policy: Configuration for the policy parameters. + actor_network_kwargs: Configuration for the actor network architecture. + critic_network_kwargs: Configuration for the critic network architecture. + discrete_critic_network_kwargs: Configuration for the discrete critic network. + policy_kwargs: Configuration for the policy parameters. n_obs_steps: Number of observation steps to consider. normalization_mapping: Mapping of feature types to normalization modes. dataset_stats: Statistics for normalizing different types of inputs. @@ -88,7 +89,7 @@ class SACConfig(PreTrainedConfig): num_discrete_actions: Number of discrete actions, eg for gripper actions. image_embedding_pooling_dim: Dimension of the image embedding pooling. concurrency: Configuration for concurrency settings. - actor_learner: Configuration for actor-learner architecture. + actor_learner_config: Configuration for actor-learner architecture. 
online_steps: Number of steps for online training. online_env_seed: Seed for the online environment. online_buffer_capacity: Capacity of the online replay buffer. @@ -140,7 +141,7 @@ class SACConfig(PreTrainedConfig): # Architecture specifics camera_number: int = 1 - device: str = "cuda" + device: str = "cpu" storage_device: str = "cpu" # Set to "helper2424/resnet10" for hil serl vision_encoder_name: str | None = None diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py new file mode 100644 index 00000000..3ee21672 --- /dev/null +++ b/tests/policies/test_sac_config.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for SAC configuration defaults and feature validation."""
+
+import pytest
+
+from lerobot.common.policies.sac.configuration_sac import (
+    ActorLearnerConfig,
+    ActorNetworkConfig,
+    ConcurrencyConfig,
+    CriticNetworkConfig,
+    PolicyConfig,
+    SACConfig,
+)
+from lerobot.configs.types import NormalizationMode
+
+
+def test_sac_config_default_initialization():
+    """A SACConfig built with no arguments exposes the expected default values."""
+    config = SACConfig()
+
+    assert config.normalization_mapping == {
+        "VISUAL": NormalizationMode.MEAN_STD,
+        "STATE": NormalizationMode.MIN_MAX,
+        "ENV": NormalizationMode.MIN_MAX,
+        "ACTION": NormalizationMode.MIN_MAX,
+    }
+
+    # Basic parameters
+    assert config.device == "cpu"
+    assert config.storage_device == "cpu"
+    assert config.discount == 0.99
+    assert config.temperature_init == 1.0
+    assert config.num_critics == 2
+
+    # Architecture specifics
+    assert config.camera_number == 1
+    assert config.vision_encoder_name is None
+    assert config.freeze_vision_encoder is True
+    assert config.image_encoder_hidden_dim == 32
+    assert config.shared_encoder is True
+    assert config.num_discrete_actions is None
+    assert config.image_embedding_pooling_dim == 8
+
+    # Training parameters
+    assert config.online_steps == 1000000
+    assert config.online_env_seed == 10000
+    assert config.online_buffer_capacity == 100000
+    assert config.offline_buffer_capacity == 100000
+    assert config.async_prefetch is False
+    assert config.online_step_before_learning == 100
+    assert config.policy_update_freq == 1
+
+    # SAC algorithm parameters
+    assert config.num_subsample_critics is None
+    assert config.critic_lr == 3e-4
+    assert config.actor_lr == 3e-4
+    assert config.temperature_lr == 3e-4
+    assert config.critic_target_update_weight == 0.005
+    assert config.utd_ratio == 1
+    assert config.state_encoder_hidden_dim == 256
+    assert config.latent_dim == 256
+    assert config.target_entropy is None
+    assert config.use_backup_entropy is True
+    assert config.grad_clip_norm == 40.0
+
+    # Dataset stats defaults
+    expected_dataset_stats = {
+        "observation.image": {
+            "mean": [0.485, 0.456, 0.406],
+            "std": [0.229, 0.224, 0.225],
+        },
+        "observation.state": {
+            "min": [0.0, 0.0],
+            "max": [1.0, 1.0],
+        },
+        "action": {
+            "min": [0.0, 0.0, 0.0],
+            "max": [1.0, 1.0, 1.0],
+        },
+    }
+    assert config.dataset_stats == expected_dataset_stats
+
+    # Critic network configuration
+    assert config.critic_network_kwargs.hidden_dims == [256, 256]
+    assert config.critic_network_kwargs.activate_final is True
+    assert config.critic_network_kwargs.final_activation is None
+
+    # Actor network configuration
+    assert config.actor_network_kwargs.hidden_dims == [256, 256]
+    assert config.actor_network_kwargs.activate_final is True
+
+    # Policy configuration
+    assert config.policy_kwargs.use_tanh_squash is True
+    assert config.policy_kwargs.log_std_min == 1e-5
+    assert config.policy_kwargs.log_std_max == 10.0
+    assert config.policy_kwargs.init_final == 0.05
+
+    # Discrete critic network configuration
+    assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
+    assert config.discrete_critic_network_kwargs.activate_final is True
+    assert config.discrete_critic_network_kwargs.final_activation is None
+
+    # Actor learner configuration
+    assert config.actor_learner_config.learner_host == "127.0.0.1"
+    assert config.actor_learner_config.learner_port == 50051
+    assert config.actor_learner_config.policy_parameters_push_frequency == 4
+
+    # Concurrency configuration
+    assert config.concurrency.actor == "threads"
+    assert config.concurrency.learner == "threads"
+
+    # Nested config types
+    assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
+    assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
+    assert isinstance(config.policy_kwargs, PolicyConfig)
+    assert isinstance(config.actor_learner_config, ActorLearnerConfig)
+    assert isinstance(config.concurrency, ConcurrencyConfig)
+
+
+def test_critic_network_kwargs():
+    """A default CriticNetworkConfig exposes the expected field values."""
+    config = CriticNetworkConfig()
+    assert config.hidden_dims == [256, 256]
+    assert config.activate_final is True
+    assert config.final_activation is None
+
+
+def test_actor_network_kwargs():
+    """A default ActorNetworkConfig exposes the expected field values."""
+    config = ActorNetworkConfig()
+    assert config.hidden_dims == [256, 256]
+    assert config.activate_final is True
+
+
+def test_policy_kwargs():
+    """A default PolicyConfig exposes the expected field values."""
+    config = PolicyConfig()
+    assert config.use_tanh_squash is True
+    assert config.log_std_min == 1e-5
+    assert config.log_std_max == 10.0
+    assert config.init_final == 0.05
+
+
+def test_actor_learner_config():
+    """A default ActorLearnerConfig exposes the expected host, port and push frequency."""
+    config = ActorLearnerConfig()
+    assert config.learner_host == "127.0.0.1"
+    assert config.learner_port == 50051
+    assert config.policy_parameters_push_frequency == 4
+
+
+def test_concurrency_config():
+    """A default ConcurrencyConfig uses "threads" for both actor and learner."""
+    config = ConcurrencyConfig()
+    assert config.actor == "threads"
+    assert config.learner == "threads"
+
+
+def test_sac_config_custom_initialization():
+    """Values passed to the SACConfig constructor are readable back from the instance."""
+    config = SACConfig(
+        # NOTE(review): "cpu" equals the default; the other fields are what prove overrides stick.
+        device="cpu",
+        discount=0.95,
+        temperature_init=0.5,
+        num_critics=3,
+    )
+
+    assert config.device == "cpu"
+    assert config.discount == 0.95
+    assert config.temperature_init == 0.5
+    assert config.num_critics == 3
+
+
+def test_validate_features():
+    """validate_features does not raise for a state observation paired with an action output."""
+    config = SACConfig(
+        input_features={"observation.state": {"shape": (10,), "type": "float32"}},
+        output_features={"action": {"shape": (3,), "type": "float32"}},
+    )
+    config.validate_features()
+
+
+def test_validate_features_missing_observation():
+    """validate_features raises ValueError when the only input feature is an unrelated key."""
+    config = SACConfig(
+        input_features={"wrong_key": {"shape": (10,), "type": "float32"}},
+        output_features={"action": {"shape": (3,), "type": "float32"}},
+    )
+    with pytest.raises(
+        ValueError, match="You must provide either 'observation.state' or an image observation"
+    ):
+        config.validate_features()
+
+
+def test_validate_features_missing_action():
+    """validate_features raises ValueError when 'action' is absent from the output features."""
+    config = SACConfig(
+        input_features={"observation.state": {"shape": (10,), "type": "float32"}},
+        output_features={"wrong_key": {"shape": (3,), "type": "float32"}},
+    )
+    with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
+        config.validate_features()
diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py
new file mode 100644
index 00000000..cfda877a
--- /dev/null
+++ b/tests/policies/test_sac_policy.py
@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+
+from lerobot.common.policies.sac.modeling_sac import MLP
+
+
+def test_mlp_with_default_args():
+    """An unbatched input of size 10 yields an output shaped by the last hidden dim (256)."""
+    mlp = MLP(input_dim=10, hidden_dims=[256, 256])
+    x = torch.randn(10)
+    assert mlp(x).shape == (256,)
+
+
+def test_mlp_with_batch_dim():
+    """A leading batch dimension on the input is preserved in the output."""
+    mlp = MLP(input_dim=10, hidden_dims=[256, 256])
+    x = torch.randn(2, 10)
+    assert mlp(x).shape == (2, 256)
+
+
+def test_forward_with_empty_hidden_dims():
+    """With an empty hidden_dims list the output keeps the input shape."""
+    mlp = MLP(input_dim=10, hidden_dims=[])
+    x = torch.randn(1, 10)
+    assert mlp(x).shape == (1, 10)
+
+
+def test_mlp_with_dropout():
+    """With dropout_rate set, output size follows the last hidden dim and mlp.net holds two Dropouts."""
+    mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
+    x = torch.randn(1, 10)
+    assert mlp(x).shape == (1, 11)
+
+    dropout_layer_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
+    assert dropout_layer_count == 2
+
+
+def test_mlp_with_custom_final_activation():
+    """With a Tanh final activation every output value lies within [-1, 1]."""
+    mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
+    y = mlp(torch.randn(1, 10))
+    assert y.shape == (1, 256)
+    assert (y >= -1).all() and (y <= 1).all()