chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -28,6 +28,7 @@ from lerobot.processor import (
)
from lerobot.processor.converters import create_transition, identity_transition
from lerobot.processor.rename_processor import rename_stats
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
from tests.conftest import assert_contract_is_typed
@@ -121,13 +122,13 @@ def test_overlapping_rename():
def test_partial_rename():
"""Test renaming only some keys."""
rename_map = {
"observation.state": "observation.proprio_state",
"pixels": "observation.image",
OBS_STATE: "observation.proprio_state",
"pixels": OBS_IMAGE,
}
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"observation.state": torch.randn(10),
OBS_STATE: torch.randn(10),
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
"reward": 1.0,
"info": {"episode": 1},
@@ -139,8 +140,8 @@ def test_partial_rename():
# Check renamed keys
assert "observation.proprio_state" in processed_obs
assert "observation.image" in processed_obs
assert "observation.state" not in processed_obs
assert OBS_IMAGE in processed_obs
assert OBS_STATE not in processed_obs
assert "pixels" not in processed_obs
# Check unchanged keys
@@ -174,8 +175,8 @@ def test_state_dict():
def test_integration_with_robot_processor():
"""Test integration with RobotProcessor pipeline."""
rename_map = {
"agent_pos": "observation.state",
"pixels": "observation.image",
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
@@ -196,8 +197,8 @@ def test_integration_with_robot_processor():
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert OBS_STATE in processed_obs
assert OBS_IMAGE in processed_obs
assert "agent_pos" not in processed_obs
assert "pixels" not in processed_obs
assert processed_obs["other_data"] == "preserve_me"
@@ -210,8 +211,8 @@ def test_integration_with_robot_processor():
def test_save_and_load_pretrained():
"""Test saving and loading processor with RobotProcessor."""
rename_map = {
"old_state": "observation.state",
"old_image": "observation.image",
"old_state": OBS_STATE,
"old_image": OBS_IMAGE,
}
processor = RenameObservationsProcessorStep(rename_map=rename_map)
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
@@ -253,10 +254,10 @@ def test_save_and_load_pretrained():
result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
assert processed_obs["observation.state"] == [1, 2, 3]
assert processed_obs["observation.image"] == "image_data"
assert OBS_STATE in processed_obs
assert OBS_IMAGE in processed_obs
assert processed_obs[OBS_STATE] == [1, 2, 3]
assert processed_obs[OBS_IMAGE] == "image_data"
def test_registry_functionality():
@@ -317,8 +318,8 @@ def test_chained_rename_processors():
# Second processor: rename to final format
processor2 = RenameObservationsProcessorStep(
rename_map={
"agent_position": "observation.state",
"camera_image": "observation.image",
"agent_position": OBS_STATE,
"camera_image": OBS_IMAGE,
}
)
@@ -342,8 +343,8 @@ def test_chained_rename_processors():
# After second processor
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert OBS_STATE in final_obs
assert OBS_IMAGE in final_obs
assert final_obs["extra"] == "keep_me"
# Original keys should be gone
@@ -356,15 +357,15 @@ def test_chained_rename_processors():
def test_nested_observation_rename():
"""Test renaming with nested observation structures."""
rename_map = {
"observation.images.left": "observation.camera.left_view",
"observation.images.right": "observation.camera.right_view",
f"{OBS_IMAGES}.left": "observation.camera.left_view",
f"{OBS_IMAGES}.right": "observation.camera.right_view",
"observation.proprio": "observation.proprioception",
}
processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = {
"observation.images.left": torch.randn(3, 64, 64),
"observation.images.right": torch.randn(3, 64, 64),
f"{OBS_IMAGES}.left": torch.randn(3, 64, 64),
f"{OBS_IMAGES}.right": torch.randn(3, 64, 64),
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
@@ -382,8 +383,8 @@ def test_nested_observation_rename():
assert "observation.gripper" in processed_obs
# Check old keys removed
assert "observation.images.left" not in processed_obs
assert "observation.images.right" not in processed_obs
assert f"{OBS_IMAGES}.left" not in processed_obs
assert f"{OBS_IMAGES}.right" not in processed_obs
assert "observation.proprio" not in processed_obs
@@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameObservationsProcessorStep(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE}
)
pipeline = DataProcessorPipeline([processor1, processor2])
@@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory):
}
out = pipeline.transform_features(initial_features=spec)
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
assert (
out[PipelineFeatureType.OBSERVATION]["observation.state"]
== spec[PipelineFeatureType.OBSERVATION]["pos"]
)
assert (
out[PipelineFeatureType.OBSERVATION]["observation.image"]
== spec[PipelineFeatureType.OBSERVATION]["img"]
)
assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"}
assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"]
assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"]
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
assert_contract_is_typed(out)
def test_rename_stats_basic():
orig = {
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
"action": {"mean": np.array([0.0])},
}
mapping = {"observation.state": "observation.robot_state"}
mapping = {OBS_STATE: "observation.robot_state"}
renamed = rename_stats(orig, mapping)
assert "observation.robot_state" in renamed and "observation.state" not in renamed
assert "observation.robot_state" in renamed and OBS_STATE not in renamed
# Ensure deep copy: mutate original and verify renamed unaffected
orig["observation.state"]["mean"][0] = 42.0
orig[OBS_STATE]["mean"][0] = 42.0
assert renamed["observation.robot_state"]["mean"][0] != 42.0