chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -342,7 +342,7 @@ def test_act_processor_batch_consistency():
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
processed = preprocessor(batch)
|
||||
assert processed["observation.state"].shape[0] == 1 # Batched
|
||||
assert processed[OBS_STATE].shape[0] == 1 # Batched
|
||||
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
|
||||
@@ -2,14 +2,15 @@ import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
"""Create a dummy batch using the new format with observation.* and next.* keys."""
|
||||
return {
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.right": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
"action": torch.tensor([[0.5]]),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
@@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip():
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
# Check that all observation.* keys are preserved
|
||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
|
||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)}
|
||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)}
|
||||
|
||||
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
|
||||
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
|
||||
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
|
||||
assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"])
|
||||
assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"])
|
||||
assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(batch_out["action"], batch_in["action"])
|
||||
@@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip():
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
@@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
||||
assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert OBS_STATE in transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"]
|
||||
)
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
@@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping():
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
|
||||
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
|
||||
OBS_STATE: [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
@@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening():
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that observation.* keys are flattened back to batch
|
||||
assert "observation.image.top" in batch
|
||||
assert "observation.image.left" in batch
|
||||
assert "observation.state" in batch
|
||||
assert f"{OBS_IMAGE}.top" in batch
|
||||
assert f"{OBS_IMAGE}.left" in batch
|
||||
assert OBS_STATE in batch
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
|
||||
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"])
|
||||
assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"])
|
||||
assert batch[OBS_STATE] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
@@ -153,12 +154,12 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
|
||||
batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"}
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
|
||||
|
||||
# Check defaults
|
||||
@@ -170,7 +171,7 @@ def test_minimal_batch():
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch[OBS_STATE] == "minimal_state"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -205,9 +206,9 @@ def test_empty_batch():
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
OBS_STATE: torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
"next.done": False,
|
||||
@@ -219,20 +220,20 @@ def test_complex_nested_observation():
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
|
||||
original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)}
|
||||
|
||||
assert original_obs_keys == reconstructed_obs_keys
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
|
||||
assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE])
|
||||
|
||||
# Check nested dict with tensors
|
||||
assert torch.allclose(
|
||||
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
|
||||
batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
|
||||
batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"]
|
||||
)
|
||||
|
||||
# Check action tensor
|
||||
@@ -264,7 +265,7 @@ def test_custom_converter():
|
||||
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
OBS_STATE: torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
@@ -274,5 +275,5 @@ def test_custom_converter():
|
||||
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
||||
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
|
||||
@@ -9,6 +9,7 @@ from lerobot.processor.converters import (
|
||||
to_tensor,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
@@ -118,16 +119,16 @@ def test_to_tensor_dictionaries():
|
||||
# Nested dictionary
|
||||
nested = {
|
||||
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
"observation": {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
}
|
||||
result = to_tensor(nested)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["action"], dict)
|
||||
assert isinstance(result["observation"], dict)
|
||||
assert isinstance(result[OBS_STR], dict)
|
||||
assert isinstance(result["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(result["observation"]["mean"], torch.Tensor)
|
||||
assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
|
||||
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))
|
||||
assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
def test_to_tensor_none_filtering():
|
||||
@@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields():
|
||||
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
@@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields():
|
||||
|
||||
# Create transition with index and task_index in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
reward=1.5,
|
||||
done=False,
|
||||
@@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields():
|
||||
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
@@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields():
|
||||
|
||||
# Transition without index/task_index
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
complementary_data={"task": ["navigate"]},
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def test_basic_functionality():
|
||||
@@ -28,7 +29,7 @@ def test_basic_functionality():
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
# Create a transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
|
||||
observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
|
||||
action = torch.randn(5)
|
||||
reward = torch.tensor(1.0)
|
||||
done = torch.tensor(False)
|
||||
@@ -41,8 +42,8 @@ def test_basic_functionality():
|
||||
result = processor(transition)
|
||||
|
||||
# Check that all tensors are on CPU
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu"
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
assert result[TransitionKey.REWARD].device.type == "cpu"
|
||||
assert result[TransitionKey.DONE].device.type == "cpu"
|
||||
@@ -55,7 +56,7 @@ def test_cuda_functionality():
|
||||
processor = DeviceProcessorStep(device="cuda")
|
||||
|
||||
# Create a transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
|
||||
observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
|
||||
action = torch.randn(5)
|
||||
reward = torch.tensor(1.0)
|
||||
done = torch.tensor(False)
|
||||
@@ -68,8 +69,8 @@ def test_cuda_functionality():
|
||||
result = processor(transition)
|
||||
|
||||
# Check that all tensors are on CUDA
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
assert result[TransitionKey.REWARD].device.type == "cuda"
|
||||
assert result[TransitionKey.DONE].device.type == "cuda"
|
||||
@@ -81,14 +82,14 @@ def test_specific_cuda_device():
|
||||
"""Test device processor with specific CUDA device."""
|
||||
processor = DeviceProcessorStep(device="cuda:0")
|
||||
|
||||
observation = {"observation.state": torch.randn(10)}
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.index == 0
|
||||
|
||||
@@ -98,7 +99,7 @@ def test_non_tensor_values():
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
OBS_STATE: torch.randn(10),
|
||||
"observation.metadata": {"key": "value"}, # Non-tensor data
|
||||
"observation.list": [1, 2, 3], # Non-tensor data
|
||||
}
|
||||
@@ -110,7 +111,7 @@ def test_non_tensor_values():
|
||||
result = processor(transition)
|
||||
|
||||
# Check tensors are processed
|
||||
assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor)
|
||||
assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor)
|
||||
assert isinstance(result[TransitionKey.ACTION], torch.Tensor)
|
||||
|
||||
# Check non-tensor values are preserved
|
||||
@@ -130,9 +131,9 @@ def test_none_values():
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert result[TransitionKey.ACTION] is None
|
||||
|
||||
|
||||
@@ -271,9 +272,7 @@ def test_features():
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
},
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
@@ -376,7 +375,7 @@ def test_reward_done_truncated_types():
|
||||
|
||||
# Test with scalar values (not tensors)
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(5)},
|
||||
observation={OBS_STATE: torch.randn(5)},
|
||||
action=torch.randn(3),
|
||||
reward=1.0, # float
|
||||
done=False, # bool
|
||||
@@ -392,7 +391,7 @@ def test_reward_done_truncated_types():
|
||||
|
||||
# Test with tensor values
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(5)},
|
||||
observation={OBS_STATE: torch.randn(5)},
|
||||
action=torch.randn(3),
|
||||
reward=torch.tensor(1.0),
|
||||
done=torch.tensor(False),
|
||||
@@ -422,7 +421,7 @@ def test_complementary_data_preserved():
|
||||
}
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data
|
||||
observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
@@ -491,13 +490,13 @@ def test_float_dtype_bfloat16():
|
||||
"""Test conversion to bfloat16."""
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16")
|
||||
|
||||
observation = {"observation.state": torch.randn(5, dtype=torch.float32)}
|
||||
observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)}
|
||||
action = torch.randn(3, dtype=torch.float64)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
|
||||
assert result[TransitionKey.ACTION].dtype == torch.bfloat16
|
||||
|
||||
|
||||
@@ -505,13 +504,13 @@ def test_float_dtype_float64():
|
||||
"""Test conversion to float64."""
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float64")
|
||||
|
||||
observation = {"observation.state": torch.randn(5, dtype=torch.float16)}
|
||||
observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)}
|
||||
action = torch.randn(3, dtype=torch.float32)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64
|
||||
assert result[TransitionKey.ACTION].dtype == torch.float64
|
||||
|
||||
|
||||
@@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors():
|
||||
processor = DeviceProcessorStep(device="cpu", float_dtype="float32")
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
|
||||
"observation.state": torch.randn(10, dtype=torch.float64), # Should convert
|
||||
OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
|
||||
OBS_STATE: torch.randn(10, dtype=torch.float64), # Should convert
|
||||
"observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert
|
||||
"observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert
|
||||
}
|
||||
@@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors():
|
||||
result = processor(transition)
|
||||
|
||||
# Check conversions
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8 # Unchanged
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted
|
||||
assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged
|
||||
assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged
|
||||
assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted
|
||||
@@ -612,7 +611,7 @@ def test_complementary_data_index_fields():
|
||||
"episode_id": 123, # Non-tensor field
|
||||
}
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
@@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda():
|
||||
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
|
||||
|
||||
# Create full transition with mixed CPU tensors
|
||||
observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)}
|
||||
observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)}
|
||||
action = torch.randn(1, 4, dtype=torch.float32)
|
||||
reward = torch.tensor(1.5, dtype=torch.float32)
|
||||
done = torch.tensor(False)
|
||||
@@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda():
|
||||
result = processor(transition)
|
||||
|
||||
# Check all components moved to CUDA
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
assert result[TransitionKey.REWARD].device.type == "cuda"
|
||||
assert result[TransitionKey.DONE].device.type == "cuda"
|
||||
@@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda():
|
||||
assert processed_comp_data["task_index"].device.type == "cuda"
|
||||
|
||||
# Check float conversion happened for float tensors
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert result[TransitionKey.ACTION].dtype == torch.float16
|
||||
assert result[TransitionKey.REWARD].dtype == torch.float16
|
||||
|
||||
@@ -782,7 +781,7 @@ def test_complementary_data_empty():
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
@@ -797,7 +796,7 @@ def test_complementary_data_none():
|
||||
processor = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
complementary_data=None,
|
||||
)
|
||||
|
||||
@@ -814,8 +813,8 @@ def test_preserves_gpu_placement():
|
||||
|
||||
# Create tensors already on GPU
|
||||
observation = {
|
||||
"observation.state": torch.randn(10).cuda(), # Already on GPU
|
||||
"observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU
|
||||
OBS_STATE: torch.randn(10).cuda(), # Already on GPU
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).cuda(), # Already on GPU
|
||||
}
|
||||
action = torch.randn(5).cuda() # Already on GPU
|
||||
|
||||
@@ -823,14 +822,12 @@ def test_preserves_gpu_placement():
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tensors remain on their original GPU
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Verify no unnecessary copies were made (same data pointer)
|
||||
assert torch.equal(
|
||||
result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"]
|
||||
)
|
||||
assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE])
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
@@ -842,8 +839,8 @@ def test_multi_gpu_preservation():
|
||||
# Create tensors on cuda:1 (simulating Accelerate placement)
|
||||
cuda1_device = torch.device("cuda:1")
|
||||
observation = {
|
||||
"observation.state": torch.randn(10).to(cuda1_device),
|
||||
"observation.image": torch.randn(3, 224, 224).to(cuda1_device),
|
||||
OBS_STATE: torch.randn(10).to(cuda1_device),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device),
|
||||
}
|
||||
action = torch.randn(5).to(cuda1_device)
|
||||
|
||||
@@ -851,20 +848,20 @@ def test_multi_gpu_preservation():
|
||||
result = processor_gpu(transition)
|
||||
|
||||
# Check that tensors remain on cuda:1 (not moved to cuda:0)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device
|
||||
assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device
|
||||
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device
|
||||
assert result[TransitionKey.ACTION].device == cuda1_device
|
||||
|
||||
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
|
||||
processor_cpu = DeviceProcessorStep(device="cpu")
|
||||
|
||||
transition_gpu = create_transition(
|
||||
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
|
||||
observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda()
|
||||
)
|
||||
result_cpu = processor_cpu(transition_gpu)
|
||||
|
||||
# Check that tensors are moved to CPU
|
||||
assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert result_cpu[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario():
|
||||
|
||||
# Simulate data already placed by Accelerate
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
observation = {"observation.state": torch.randn(1, 10).to(device)}
|
||||
observation = {OBS_STATE: torch.randn(1, 10).to(device)}
|
||||
action = torch.randn(1, 5).to(device)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
# Verify data stays on the GPU where Accelerate placed it
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device == device
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert result[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data():
|
||||
}
|
||||
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(5, dtype=torch.float64)},
|
||||
observation={OBS_STATE: torch.randn(5, dtype=torch.float64)},
|
||||
action=torch.randn(3, dtype=torch.float64),
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
@@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data():
|
||||
result = processor(transition)
|
||||
|
||||
# Check that all tensors are on MPS device
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps"
|
||||
assert result[TransitionKey.ACTION].device.type == "mps"
|
||||
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
@@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data():
|
||||
assert processed_comp_data["float32_tensor"].device.type == "mps"
|
||||
|
||||
# Check dtype conversions
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted
|
||||
assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted
|
||||
assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted
|
||||
assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged
|
||||
|
||||
@@ -25,6 +25,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
|
||||
def test_is_processor_config_valid_configs():
|
||||
@@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only():
|
||||
# Create a model config (like old LeRobot format)
|
||||
model_config = {
|
||||
"type": "act",
|
||||
"input_features": {"observation.state": {"shape": [7]}},
|
||||
"input_features": {OBS_STATE: {"shape": [7]}},
|
||||
"output_features": {"action": {"shape": [7]}},
|
||||
"hidden_dim": 256,
|
||||
"n_obs_steps": 1,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -39,8 +39,8 @@ def test_process_single_image():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert "observation.image" in processed_obs
|
||||
processed_img = processed_obs["observation.image"]
|
||||
assert OBS_IMAGE in processed_obs
|
||||
processed_img = processed_obs[OBS_IMAGE]
|
||||
|
||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||
assert processed_img.shape == (1, 3, 64, 64)
|
||||
@@ -66,12 +66,12 @@ def test_process_image_dict():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both images were processed
|
||||
assert "observation.images.camera1" in processed_obs
|
||||
assert "observation.images.camera2" in processed_obs
|
||||
assert f"{OBS_IMAGES}.camera1" in processed_obs
|
||||
assert f"{OBS_IMAGES}.camera2" in processed_obs
|
||||
|
||||
# Check shapes
|
||||
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
|
||||
assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48)
|
||||
|
||||
|
||||
def test_process_batched_image():
|
||||
@@ -88,7 +88,7 @@ def test_process_batched_image():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64)
|
||||
|
||||
|
||||
def test_invalid_image_format():
|
||||
@@ -173,10 +173,10 @@ def test_process_environment_state():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.environment_state"]
|
||||
processed_state = processed_obs[OBS_ENV_STATE]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
@@ -194,10 +194,10 @@ def test_process_agent_pos():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.state"]
|
||||
processed_state = processed_obs[OBS_STATE]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
||||
@@ -217,8 +217,8 @@ def test_process_batched_states():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
assert processed_obs["observation.state"].shape == (2, 2)
|
||||
assert processed_obs[OBS_ENV_STATE].shape == (2, 2)
|
||||
assert processed_obs[OBS_STATE].shape == (2, 2)
|
||||
|
||||
|
||||
def test_process_both_states():
|
||||
@@ -235,8 +235,8 @@ def test_process_both_states():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "environment_state" not in processed_obs
|
||||
@@ -281,12 +281,12 @@ def test_complete_observation_processing():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32)
|
||||
|
||||
# Check that states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_ENV_STATE in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "pixels" not in processed_obs
|
||||
@@ -308,7 +308,7 @@ def test_image_only_processing():
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ def test_state_only_processing():
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
|
||||
@@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessorStep()
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
},
|
||||
}
|
||||
@@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
|
||||
assert (
|
||||
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
|
||||
== features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
|
||||
)
|
||||
assert (
|
||||
OBS_STATE in out[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
@@ -35,6 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -255,7 +256,7 @@ def test_step_through_with_dict():
|
||||
pipeline = DataProcessorPipeline([step1, step2])
|
||||
|
||||
batch = {
|
||||
"observation.image": None,
|
||||
OBS_IMAGE: None,
|
||||
"action": None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
@@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
|
||||
# Verify it uses default converters by checking with standard batch format
|
||||
batch = {
|
||||
"observation.image": torch.randn(1, 3, 32, 32),
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
"action": torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
@@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
# Should work with standard format (wouldn't work with custom converter)
|
||||
result = loaded(batch)
|
||||
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
|
||||
assert "observation.image" in result
|
||||
assert OBS_IMAGE in result
|
||||
|
||||
|
||||
class NonCompliantStep:
|
||||
@@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep):
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# State features (mix EE and a joint state)
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float
|
||||
if self.add_front_image:
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape
|
||||
return features
|
||||
|
||||
|
||||
@@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only():
|
||||
)
|
||||
|
||||
# Expect only "action" with joint names
|
||||
assert "action" in out and "observation.state" not in out
|
||||
assert "action" in out and OBS_STATE not in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out["action"]["shape"] == (len(out["action"]["names"]),)
|
||||
@@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "observation.state"],
|
||||
patterns=["action.ee", OBS_STATE],
|
||||
)
|
||||
|
||||
# Action should pack only EE names
|
||||
@@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
|
||||
# Observation state should pack both ee.x and j1.pos as a vector
|
||||
assert "observation.state" in out
|
||||
assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert out["observation.state"]["dtype"] == "float32"
|
||||
assert OBS_STATE in out
|
||||
assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert out[OBS_STATE]["dtype"] == "float32"
|
||||
|
||||
# Cameras from initial_features appear as videos
|
||||
for cam in ("front", "side"):
|
||||
key = f"observation.images.{cam}"
|
||||
key = f"{OBS_IMAGES}.{cam}"
|
||||
assert key in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
assert out[key]["shape"] == initial[cam]
|
||||
@@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false():
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = "observation.images.back"
|
||||
key_front = "observation.images.front"
|
||||
key = f"{OBS_IMAGES}.back"
|
||||
key_front = f"{OBS_IMAGES}.front"
|
||||
assert key not in out
|
||||
assert key_front not in out
|
||||
|
||||
@@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true():
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = "observation.images.front"
|
||||
key_back = "observation.images.back"
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
key_back = f"{OBS_IMAGES}.back"
|
||||
assert key in out
|
||||
assert key_back in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
@@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image():
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=True,
|
||||
patterns=["observation.images.front"],
|
||||
patterns=[f"{OBS_IMAGES}.front"],
|
||||
)
|
||||
|
||||
key = "observation.images.front"
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
assert key in out
|
||||
assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -121,13 +122,13 @@ def test_overlapping_rename():
|
||||
def test_partial_rename():
|
||||
"""Test renaming only some keys."""
|
||||
rename_map = {
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
OBS_STATE: "observation.proprio_state",
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
OBS_STATE: torch.randn(10),
|
||||
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
|
||||
"reward": 1.0,
|
||||
"info": {"episode": 1},
|
||||
@@ -139,8 +140,8 @@ def test_partial_rename():
|
||||
|
||||
# Check renamed keys
|
||||
assert "observation.proprio_state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert "observation.state" not in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert OBS_STATE not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
|
||||
# Check unchanged keys
|
||||
@@ -174,8 +175,8 @@ def test_state_dict():
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test integration with RobotProcessor pipeline."""
|
||||
rename_map = {
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
@@ -196,8 +197,8 @@ def test_integration_with_robot_processor():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renaming worked through pipeline
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
@@ -210,8 +211,8 @@ def test_integration_with_robot_processor():
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading processor with RobotProcessor."""
|
||||
rename_map = {
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
"old_state": OBS_STATE,
|
||||
"old_image": OBS_IMAGE,
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
|
||||
@@ -253,10 +254,10 @@ def test_save_and_load_pretrained():
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.state"] == [1, 2, 3]
|
||||
assert processed_obs["observation.image"] == "image_data"
|
||||
assert OBS_STATE in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert processed_obs[OBS_STATE] == [1, 2, 3]
|
||||
assert processed_obs[OBS_IMAGE] == "image_data"
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
@@ -317,8 +318,8 @@ def test_chained_rename_processors():
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
"agent_position": OBS_STATE,
|
||||
"camera_image": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -342,8 +343,8 @@ def test_chained_rename_processors():
|
||||
|
||||
# After second processor
|
||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in final_obs
|
||||
assert "observation.image" in final_obs
|
||||
assert OBS_STATE in final_obs
|
||||
assert OBS_IMAGE in final_obs
|
||||
assert final_obs["extra"] == "keep_me"
|
||||
|
||||
# Original keys should be gone
|
||||
@@ -356,15 +357,15 @@ def test_chained_rename_processors():
|
||||
def test_nested_observation_rename():
|
||||
"""Test renaming with nested observation structures."""
|
||||
rename_map = {
|
||||
"observation.images.left": "observation.camera.left_view",
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
f"{OBS_IMAGES}.left": "observation.camera.left_view",
|
||||
f"{OBS_IMAGES}.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
"observation.images.right": torch.randn(3, 64, 64),
|
||||
f"{OBS_IMAGES}.left": torch.randn(3, 64, 64),
|
||||
f"{OBS_IMAGES}.right": torch.randn(3, 64, 64),
|
||||
"observation.proprio": torch.randn(7),
|
||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||
}
|
||||
@@ -382,8 +383,8 @@ def test_nested_observation_rename():
|
||||
assert "observation.gripper" in processed_obs
|
||||
|
||||
# Check old keys removed
|
||||
assert "observation.images.left" not in processed_obs
|
||||
assert "observation.images.right" not in processed_obs
|
||||
assert f"{OBS_IMAGES}.left" not in processed_obs
|
||||
assert f"{OBS_IMAGES}.right" not in processed_obs
|
||||
assert "observation.proprio" not in processed_obs
|
||||
|
||||
|
||||
@@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE}
|
||||
)
|
||||
pipeline = DataProcessorPipeline([processor1, processor2])
|
||||
|
||||
@@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
}
|
||||
out = pipeline.transform_features(initial_features=spec)
|
||||
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.state"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
)
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.image"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
)
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"}
|
||||
assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
mapping = {OBS_STATE: "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
assert "observation.robot_state" in renamed and OBS_STATE not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
orig[OBS_STATE]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_LANGUAGE
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -503,16 +503,14 @@ def test_features_basic():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
|
||||
|
||||
input_features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
},
|
||||
PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that original features are preserved
|
||||
assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert "action" in output_features[PipelineFeatureType.ACTION]
|
||||
|
||||
# Check that tokenized features are added
|
||||
@@ -797,7 +795,7 @@ def test_device_detection_cpu():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10)} # CPU tensor
|
||||
observation = {OBS_STATE: torch.randn(10)} # CPU tensor
|
||||
action = torch.randn(5) # CPU tensor
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "test task"}
|
||||
@@ -821,7 +819,7 @@ def test_device_detection_cuda():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor
|
||||
observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
|
||||
action = torch.randn(5).cuda() # CUDA tensor
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "test task"}
|
||||
@@ -847,7 +845,7 @@ def test_device_detection_multi_gpu():
|
||||
|
||||
# Test with tensors on cuda:1
|
||||
device = torch.device("cuda:1")
|
||||
observation = {"observation.state": torch.randn(10).to(device)}
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(
|
||||
observation=observation, action=action, complementary_data={"task": "multi gpu test"}
|
||||
@@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with float tensor (to test dtype isn't affected)
|
||||
observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)}
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "dtype test"})
|
||||
|
||||
result = processor(transition)
|
||||
@@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
|
||||
# Start with CPU tensors
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(10)}, # CPU
|
||||
observation={OBS_STATE: torch.randn(10)}, # CPU
|
||||
action=torch.randn(5), # CPU
|
||||
complementary_data={"task": "pipeline test"},
|
||||
)
|
||||
@@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
result = robot_processor(transition)
|
||||
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessorStep)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Tokenized tensors should also be on CUDA
|
||||
@@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario():
|
||||
# Simulate Accelerate scenario: batch already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
"observation.state": torch.randn(1, 10).to(device), # Batched, on GPU
|
||||
"observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU
|
||||
OBS_STATE: torch.randn(1, 10).to(device), # Batched, on GPU
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU
|
||||
}
|
||||
action = torch.randn(1, 5).to(device) # Batched, on GPU
|
||||
|
||||
|
||||
Reference in New Issue
Block a user