chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -25,7 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
@@ -46,7 +46,7 @@ def test_sac_config_default_initialization():
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -99,7 +99,7 @@ def test_sac_config_default_initialization():
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -193,7 +193,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
@@ -201,7 +201,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"