[PORT HIL-SERL] Better unit tests coverage for SAC policy (#1074)

This commit is contained in:
Eugene Mironov
2025-05-14 21:41:36 +07:00
committed by AdilZouitine
parent f8a963b86f
commit 5902f8fcc7
4 changed files with 492 additions and 25 deletions

View File

@@ -24,7 +24,7 @@ from lerobot.common.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.configs.types import NormalizationMode
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def test_sac_config_default_initialization():
@@ -192,16 +192,16 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
output_features={"action": {"shape": (3,), "type": "float32"}},
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": {"shape": (10,), "type": "float32"}},
output_features={"action": {"shape": (3,), "type": "float32"}},
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"
@@ -211,8 +211,8 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
output_features={"wrong_key": {"shape": (3,), "type": "float32"}},
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
config.validate_features()