forked from tangger/lerobot
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -35,7 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ def test_step_through_with_dict():
|
||||
|
||||
batch = {
|
||||
OBS_IMAGE: None,
|
||||
"action": None,
|
||||
ACTION: None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
@@ -1842,7 +1842,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
# Verify it uses default converters by checking with standard batch format
|
||||
batch = {
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
"action": torch.randn(1, 7),
|
||||
ACTION: torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
"next.truncated": torch.tensor([False]),
|
||||
@@ -2094,11 +2094,11 @@ def test_aggregate_joint_action_only():
|
||||
patterns=["action.j1.pos", "action.j2.pos"],
|
||||
)
|
||||
|
||||
# Expect only "action" with joint names
|
||||
assert "action" in out and OBS_STATE not in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out["action"]["shape"] == (len(out["action"]["names"]),)
|
||||
# Expect only ACTION with joint names
|
||||
assert ACTION in out and OBS_STATE not in out
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),)
|
||||
|
||||
|
||||
def test_aggregate_ee_action_and_observation_with_videos():
|
||||
@@ -2113,9 +2113,9 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
)
|
||||
|
||||
# Action should pack only EE names
|
||||
assert "action" in out
|
||||
assert set(out["action"]["names"]) == {"ee.x", "ee.y"}
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert ACTION in out
|
||||
assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"}
|
||||
assert out[ACTION]["dtype"] == "float32"
|
||||
|
||||
# Observation state should pack both ee.x and j1.pos as a vector
|
||||
assert OBS_STATE in out
|
||||
@@ -2140,10 +2140,10 @@ def test_aggregate_both_action_types():
|
||||
patterns=["action.ee", "action.j1", "action.j2.pos"],
|
||||
)
|
||||
|
||||
assert "action" in out
|
||||
assert ACTION in out
|
||||
expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"}
|
||||
assert set(out["action"]["names"]) == expected
|
||||
assert out["action"]["shape"] == (len(expected),)
|
||||
assert set(out[ACTION]["names"]) == expected
|
||||
assert out[ACTION]["shape"] == (len(expected),)
|
||||
|
||||
|
||||
def test_aggregate_images_when_use_videos_false():
|
||||
|
||||
Reference in New Issue
Block a user