chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -2,14 +2,15 @@ import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
"""Create a dummy batch using the new format with observation.* and next.* keys."""
|
||||
return {
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.right": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
"action": torch.tensor([[0.5]]),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
@@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip():
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
# Check that all observation.* keys are preserved
|
||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
|
||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)}
|
||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)}
|
||||
|
||||
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
|
||||
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
|
||||
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
|
||||
assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"])
|
||||
assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"])
|
||||
assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(batch_out["action"], batch_in["action"])
|
||||
@@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip():
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
@@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
||||
assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert OBS_STATE in transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"]
|
||||
)
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
@@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping():
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
@@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening():
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that observation.* keys are flattened back to batch
|
||||
assert "observation.image.top" in batch
|
||||
assert "observation.image.left" in batch
|
||||
assert "observation.state" in batch
|
||||
assert f"{OBS_IMAGE}.top" in batch
|
||||
assert f"{OBS_IMAGE}.left" in batch
|
||||
assert OBS_STATE in batch
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
|
||||
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"])
|
||||
assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"])
|
||||
assert batch[OBS_STATE] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
@@ -153,12 +154,12 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
|
||||
batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"}
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
|
||||
|
||||
# Check defaults
|
||||
@@ -170,7 +171,7 @@ def test_minimal_batch():
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch[OBS_STATE] == "minimal_state"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -205,9 +206,9 @@ def test_empty_batch():
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
OBS_STATE: torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
"next.done": False,
|
||||
@@ -219,20 +220,20 @@ def test_complex_nested_observation():
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
|
||||
original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)}
|
||||
|
||||
assert original_obs_keys == reconstructed_obs_keys
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
|
||||
assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE])
|
||||
|
||||
# Check nested dict with tensors
|
||||
assert torch.allclose(
|
||||
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
|
||||
batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
|
||||
batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"]
|
||||
)
|
||||
|
||||
# Check action tensor
|
||||
@@ -264,7 +265,7 @@ def test_custom_converter():
|
||||
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
OBS_STATE: torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
@@ -274,5 +275,5 @@ def test_custom_converter():
|
||||
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
||||
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
|
||||
Reference in New Issue
Block a user