chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -35,7 +35,7 @@ from lerobot.processor import (
TransitionKey,
)
from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from tests.conftest import assert_contract_is_typed
@@ -257,7 +257,7 @@ def test_step_through_with_dict():
batch = {
OBS_IMAGE: None,
"action": None,
ACTION: None,
"next.reward": 0.0,
"next.done": False,
"next.truncated": False,
@@ -1842,7 +1842,7 @@ def test_save_load_with_custom_converter_functions():
# Verify it uses default converters by checking with standard batch format
batch = {
OBS_IMAGE: torch.randn(1, 3, 32, 32),
"action": torch.randn(1, 7),
ACTION: torch.randn(1, 7),
"next.reward": torch.tensor([1.0]),
"next.done": torch.tensor([False]),
"next.truncated": torch.tensor([False]),
@@ -2094,11 +2094,11 @@ def test_aggregate_joint_action_only():
patterns=["action.j1.pos", "action.j2.pos"],
)
# Expect only "action" with joint names
assert "action" in out and OBS_STATE not in out
assert out["action"]["dtype"] == "float32"
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
assert out["action"]["shape"] == (len(out["action"]["names"]),)
# Expect only ACTION with joint names
assert ACTION in out and OBS_STATE not in out
assert out[ACTION]["dtype"] == "float32"
assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"}
assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),)
def test_aggregate_ee_action_and_observation_with_videos():
@@ -2113,9 +2113,9 @@ def test_aggregate_ee_action_and_observation_with_videos():
)
# Action should pack only EE names
assert "action" in out
assert set(out["action"]["names"]) == {"ee.x", "ee.y"}
assert out["action"]["dtype"] == "float32"
assert ACTION in out
assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"}
assert out[ACTION]["dtype"] == "float32"
# Observation state should pack both ee.x and j1.pos as a vector
assert OBS_STATE in out
@@ -2140,10 +2140,10 @@ def test_aggregate_both_action_types():
patterns=["action.ee", "action.j1", "action.j2.pos"],
)
assert "action" in out
assert ACTION in out
expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"}
assert set(out["action"]["names"]) == expected
assert out["action"]["shape"] == (len(expected),)
assert set(out[ACTION]["names"]) == expected
assert out[ACTION]["shape"] == (len(expected),)
def test_aggregate_images_when_use_videos_false():