chore: replace hard-coded action values with constants throughout the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout the source code

* chore(tests): replace hard-coded action values with constants throughout the test code
Steven Palma, 2025-09-26 13:33:18 +02:00 (committed by GitHub)
parent 9627765ce2, commit d2782cf66b
47 changed files with 269 additions and 255 deletions
@@ -2,7 +2,7 @@ import torch
from lerobot.processor import DataProcessorPipeline, TransitionKey
from lerobot.processor.converters import batch_to_transition, transition_to_batch
-from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE
def _dummy_batch():
@@ -11,7 +11,7 @@ def _dummy_batch():
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
"action": torch.tensor([[0.5]]),
ACTION: torch.tensor([[0.5]]),
"next.reward": 1.0,
"next.done": False,
"next.truncated": False,
@@ -37,7 +37,7 @@ def test_observation_grouping_roundtrip():
assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])
# Check other fields
assert torch.allclose(batch_out["action"], batch_in["action"])
assert torch.allclose(batch_out[ACTION], batch_in[ACTION])
assert batch_out["next.reward"] == batch_in["next.reward"]
assert batch_out["next.done"] == batch_in["next.done"]
assert batch_out["next.truncated"] == batch_in["next.truncated"]
@@ -50,7 +50,7 @@ def test_batch_to_transition_observation_grouping():
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
OBS_STATE: [1, 2, 3, 4],
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]),
"next.reward": 1.5,
"next.done": True,
"next.truncated": False,
@@ -114,7 +114,7 @@ def test_transition_to_batch_observation_flattening():
assert batch[OBS_STATE] == [1, 2, 3, 4]
# Check other fields are mapped to next.* format
assert batch["action"] == "action_data"
assert batch[ACTION] == "action_data"
assert batch["next.reward"] == 1.5
assert batch["next.done"]
assert not batch["next.truncated"]
@@ -124,7 +124,7 @@ def test_transition_to_batch_observation_flattening():
def test_no_observation_keys():
"""Test behavior when there are no observation.* keys."""
batch = {
"action": torch.tensor([1.0, 2.0]),
ACTION: torch.tensor([1.0, 2.0]),
"next.reward": 2.0,
"next.done": False,
"next.truncated": True,
@@ -145,7 +145,7 @@ def test_no_observation_keys():
# Round trip should work
reconstructed_batch = transition_to_batch(transition)
assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0]))
assert reconstructed_batch["next.reward"] == 2.0
assert not reconstructed_batch["next.done"]
assert reconstructed_batch["next.truncated"]
@@ -154,7 +154,7 @@ def test_no_observation_keys():
def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action."""
batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}
batch = {OBS_STATE: "minimal_state", ACTION: torch.tensor([0.5])}
transition = batch_to_transition(batch)
@@ -172,7 +172,7 @@ def test_minimal_batch():
# Round trip
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch[OBS_STATE] == "minimal_state"
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5]))
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
@@ -196,7 +196,7 @@ def test_empty_batch():
# Round trip
reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["action"] is None
assert reconstructed_batch[ACTION] is None
assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"]
assert not reconstructed_batch["next.truncated"]
@@ -209,7 +209,7 @@ def test_complex_nested_observation():
f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
OBS_STATE: torch.randn(7),
"action": torch.randn(8),
ACTION: torch.randn(8),
"next.reward": 3.14,
"next.done": False,
"next.truncated": True,
@@ -237,7 +237,7 @@ def test_complex_nested_observation():
)
# Check action tensor
assert torch.allclose(batch["action"], reconstructed_batch["action"])
assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION])
# Check other fields
assert batch["next.reward"] == reconstructed_batch["next.reward"]
@@ -266,7 +266,7 @@ def test_custom_converter():
batch = {
OBS_STATE: torch.randn(1, 4),
"action": torch.randn(1, 2),
ACTION: torch.randn(1, 2),
"next.reward": 1.0,
"next.done": False,
}
@@ -276,4 +276,4 @@ def test_custom_converter():
# Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
assert torch.allclose(result["action"], batch["action"])
assert torch.allclose(result[ACTION], batch[ACTION])
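
Taken together, the updated tests exercise the same round trip as before, just keyed by the constant. A minimal usage sketch of the converters under test, assuming only the imports already shown in this diff:

    import torch

    from lerobot.processor.converters import batch_to_transition, transition_to_batch
    from lerobot.utils.constants import ACTION, OBS_STATE

    # Build a batch keyed by the constants rather than string literals.
    batch = {OBS_STATE: torch.tensor([[0.1, 0.2]]), ACTION: torch.tensor([[0.5]])}

    # Round-trip through the converters; the ACTION entry survives unchanged.
    transition = batch_to_transition(batch)
    reconstructed = transition_to_batch(transition)
    assert torch.allclose(reconstructed[ACTION], batch[ACTION])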