[PORT HIL-SERL] Add unit tests for SAC modeling (#999)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -70,9 +70,10 @@ class SACConfig(PreTrainedConfig):
|
||||
hyperparameters.
|
||||
|
||||
Args:
|
||||
actor_network: Configuration for the actor network architecture.
|
||||
critic_network: Configuration for the critic network architecture.
|
||||
policy: Configuration for the policy parameters.
|
||||
actor_network_kwargs: Configuration for the actor network architecture.
|
||||
critic_network_kwargs: Configuration for the critic network architecture.
|
||||
discrete_critic_network_kwargs: Configuration for the discrete critic network.
|
||||
policy_kwargs: Configuration for the policy parameters.
|
||||
n_obs_steps: Number of observation steps to consider.
|
||||
normalization_mapping: Mapping of feature types to normalization modes.
|
||||
dataset_stats: Statistics for normalizing different types of inputs.
|
||||
@@ -88,7 +89,7 @@ class SACConfig(PreTrainedConfig):
|
||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
actor_learner_config: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
online_env_seed: Seed for the online environment.
|
||||
online_buffer_capacity: Capacity of the online replay buffer.
|
||||
@@ -140,7 +141,7 @@ class SACConfig(PreTrainedConfig):
|
||||
|
||||
# Architecture specifics
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
device: str = "cpu"
|
||||
storage_device: str = "cpu"
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
|
||||
218
tests/policies/test_sac_config.py
Normal file
218
tests/policies/test_sac_config.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
assert config.dataset_stats == {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
|
||||
# Basic parameters
|
||||
assert config.device == "cpu"
|
||||
assert config.storage_device == "cpu"
|
||||
assert config.discount == 0.99
|
||||
assert config.temperature_init == 1.0
|
||||
assert config.num_critics == 2
|
||||
|
||||
# Architecture specifics
|
||||
assert config.camera_number == 1
|
||||
assert config.vision_encoder_name is None
|
||||
assert config.freeze_vision_encoder is True
|
||||
assert config.image_encoder_hidden_dim == 32
|
||||
assert config.shared_encoder is True
|
||||
assert config.num_discrete_actions is None
|
||||
assert config.image_embedding_pooling_dim == 8
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
assert config.online_step_before_learning == 100
|
||||
assert config.policy_update_freq == 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
assert config.num_subsample_critics is None
|
||||
assert config.critic_lr == 3e-4
|
||||
assert config.actor_lr == 3e-4
|
||||
assert config.temperature_lr == 3e-4
|
||||
assert config.critic_target_update_weight == 0.005
|
||||
assert config.utd_ratio == 1
|
||||
assert config.state_encoder_hidden_dim == 256
|
||||
assert config.latent_dim == 256
|
||||
assert config.target_entropy is None
|
||||
assert config.use_backup_entropy is True
|
||||
assert config.grad_clip_norm == 40.0
|
||||
|
||||
# Dataset stats defaults
|
||||
expected_dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
assert config.dataset_stats == expected_dataset_stats
|
||||
|
||||
# Critic network configuration
|
||||
assert config.critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.critic_network_kwargs.activate_final is True
|
||||
assert config.critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor network configuration
|
||||
assert config.actor_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.actor_network_kwargs.activate_final is True
|
||||
|
||||
# Policy configuration
|
||||
assert config.policy_kwargs.use_tanh_squash is True
|
||||
assert config.policy_kwargs.log_std_min == 1e-5
|
||||
assert config.policy_kwargs.log_std_max == 10.0
|
||||
assert config.policy_kwargs.init_final == 0.05
|
||||
|
||||
# Discrete critic network configuration
|
||||
assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256]
|
||||
assert config.discrete_critic_network_kwargs.activate_final is True
|
||||
assert config.discrete_critic_network_kwargs.final_activation is None
|
||||
|
||||
# Actor learner configuration
|
||||
assert config.actor_learner_config.learner_host == "127.0.0.1"
|
||||
assert config.actor_learner_config.learner_port == 50051
|
||||
assert config.actor_learner_config.policy_parameters_push_frequency == 4
|
||||
|
||||
# Concurrency configuration
|
||||
assert config.concurrency.actor == "threads"
|
||||
assert config.concurrency.learner == "threads"
|
||||
|
||||
assert isinstance(config.actor_network_kwargs, ActorNetworkConfig)
|
||||
assert isinstance(config.critic_network_kwargs, CriticNetworkConfig)
|
||||
assert isinstance(config.policy_kwargs, PolicyConfig)
|
||||
assert isinstance(config.actor_learner_config, ActorLearnerConfig)
|
||||
assert isinstance(config.concurrency, ConcurrencyConfig)
|
||||
|
||||
|
||||
def test_critic_network_kwargs():
|
||||
config = CriticNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
assert config.final_activation is None
|
||||
|
||||
|
||||
def test_actor_network_kwargs():
|
||||
config = ActorNetworkConfig()
|
||||
assert config.hidden_dims == [256, 256]
|
||||
assert config.activate_final is True
|
||||
|
||||
|
||||
def test_policy_kwargs():
|
||||
config = PolicyConfig()
|
||||
assert config.use_tanh_squash is True
|
||||
assert config.log_std_min == 1e-5
|
||||
assert config.log_std_max == 10.0
|
||||
assert config.init_final == 0.05
|
||||
|
||||
|
||||
def test_actor_learner_config():
|
||||
config = ActorLearnerConfig()
|
||||
assert config.learner_host == "127.0.0.1"
|
||||
assert config.learner_port == 50051
|
||||
assert config.policy_parameters_push_frequency == 4
|
||||
|
||||
|
||||
def test_concurrency_config():
|
||||
config = ConcurrencyConfig()
|
||||
assert config.actor == "threads"
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
num_critics=3,
|
||||
)
|
||||
|
||||
assert config.device == "cpu"
|
||||
assert config.discount == 0.95
|
||||
assert config.temperature_init == 0.5
|
||||
assert config.num_critics == 3
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"action": {"shape": (3,), "type": "float32"}},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"action": {"shape": (3,), "type": "float32"}},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"wrong_key": {"shape": (3,), "type": "float32"}},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
43
tests/policies/test_sac_policy.py
Normal file
43
tests/policies/test_sac_policy.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP
|
||||
|
||||
|
||||
def test_mlp_with_default_args():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
|
||||
x = torch.randn(10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (256,)
|
||||
|
||||
|
||||
def test_mlp_with_batch_dim():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256])
|
||||
x = torch.randn(2, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (2, 256)
|
||||
|
||||
|
||||
def test_forward_with_empty_hidden_dims():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[])
|
||||
x = torch.randn(1, 10)
|
||||
assert mlp(x).shape == (1, 10)
|
||||
|
||||
|
||||
def test_mlp_with_dropout():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1)
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 11)
|
||||
|
||||
drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net)
|
||||
assert drop_out_layers_count == 2
|
||||
|
||||
|
||||
def test_mlp_with_custom_final_activation():
|
||||
mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh())
|
||||
x = torch.randn(1, 10)
|
||||
y = mlp(x)
|
||||
assert y.shape == (1, 256)
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
Reference in New Issue
Block a user