[PORT HIL-SERL] Better unit tests coverage for SAC policy (#1074)
This commit is contained in:
committed by
AdilZouitine
parent
f8a963b86f
commit
5902f8fcc7
@@ -24,7 +24,7 @@ from lerobot.common.policies.sac.configuration_sac import (
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
@@ -192,16 +192,16 @@ def test_sac_config_custom_initialization():
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"action": {"shape": (3,), "type": "float32"}},
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
input_features={"wrong_key": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"action": {"shape": (3,), "type": "float32"}},
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="You must provide either 'observation.state' or an image observation"
|
||||
@@ -211,8 +211,8 @@ def test_validate_features_missing_observation():
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
input_features={"observation.state": {"shape": (10,), "type": "float32"}},
|
||||
output_features={"wrong_key": {"shape": (3,), "type": "float32"}},
|
||||
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
|
||||
config.validate_features()
|
||||
|
||||
Reference in New Issue
Block a user