chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
@@ -37,11 +38,11 @@ def test_sac_config_default_initialization():
"ACTION": NormalizationMode.MIN_MAX,
}
assert config.dataset_stats == {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
@@ -90,11 +91,11 @@ def test_sac_config_default_initialization():
# Dataset stats defaults
expected_dataset_stats = {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
@@ -191,7 +192,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
@@ -210,7 +211,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):