chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -59,7 +59,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
},
}
motor_features = {
"action": {
ACTION: {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
@@ -287,7 +287,7 @@ def test_multikey_construction(multikey: bool):
),
}
output_features = {
"action": PolicyFeature(
ACTION: PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
@@ -304,7 +304,7 @@ def test_multikey_construction(multikey: bool):
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
output_features["action"] = PolicyFeature(
output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
)

View File

@@ -25,7 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
@@ -46,7 +46,7 @@ def test_sac_config_default_initialization():
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -99,7 +99,7 @@ def test_sac_config_default_initialization():
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -193,7 +193,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
@@ -201,7 +201,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"

View File

@@ -23,7 +23,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
@@ -105,7 +105,7 @@ def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
@@ -117,7 +117,7 @@ def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
@@ -182,13 +182,13 @@ def create_default_config(
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
"action": {
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},