chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -121,13 +122,13 @@ def test_overlapping_rename():
|
||||
def test_partial_rename():
|
||||
"""Test renaming only some keys."""
|
||||
rename_map = {
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
OBS_STATE: "observation.proprio_state",
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
OBS_STATE: torch.randn(10),
|
||||
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
|
||||
"reward": 1.0,
|
||||
"info": {"episode": 1},
|
||||
@@ -139,8 +140,8 @@ def test_partial_rename():
|
||||
|
||||
# Check renamed keys
|
||||
assert "observation.proprio_state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert "observation.state" not in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert OBS_STATE not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
|
||||
# Check unchanged keys
|
||||
@@ -174,8 +175,8 @@ def test_state_dict():
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test integration with RobotProcessor pipeline."""
|
||||
rename_map = {
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
@@ -196,8 +197,8 @@ def test_integration_with_robot_processor():
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renaming worked through pipeline
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert OBS_STATE in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
@@ -210,8 +211,8 @@ def test_integration_with_robot_processor():
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading processor with RobotProcessor."""
|
||||
rename_map = {
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
"old_state": OBS_STATE,
|
||||
"old_image": OBS_IMAGE,
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
|
||||
@@ -253,10 +254,10 @@ def test_save_and_load_pretrained():
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.state"] == [1, 2, 3]
|
||||
assert processed_obs["observation.image"] == "image_data"
|
||||
assert OBS_STATE in processed_obs
|
||||
assert OBS_IMAGE in processed_obs
|
||||
assert processed_obs[OBS_STATE] == [1, 2, 3]
|
||||
assert processed_obs[OBS_IMAGE] == "image_data"
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
@@ -317,8 +318,8 @@ def test_chained_rename_processors():
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
"agent_position": OBS_STATE,
|
||||
"camera_image": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -342,8 +343,8 @@ def test_chained_rename_processors():
|
||||
|
||||
# After second processor
|
||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in final_obs
|
||||
assert "observation.image" in final_obs
|
||||
assert OBS_STATE in final_obs
|
||||
assert OBS_IMAGE in final_obs
|
||||
assert final_obs["extra"] == "keep_me"
|
||||
|
||||
# Original keys should be gone
|
||||
@@ -356,15 +357,15 @@ def test_chained_rename_processors():
|
||||
def test_nested_observation_rename():
|
||||
"""Test renaming with nested observation structures."""
|
||||
rename_map = {
|
||||
"observation.images.left": "observation.camera.left_view",
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
f"{OBS_IMAGES}.left": "observation.camera.left_view",
|
||||
f"{OBS_IMAGES}.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
"observation.images.right": torch.randn(3, 64, 64),
|
||||
f"{OBS_IMAGES}.left": torch.randn(3, 64, 64),
|
||||
f"{OBS_IMAGES}.right": torch.randn(3, 64, 64),
|
||||
"observation.proprio": torch.randn(7),
|
||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||
}
|
||||
@@ -382,8 +383,8 @@ def test_nested_observation_rename():
|
||||
assert "observation.gripper" in processed_obs
|
||||
|
||||
# Check old keys removed
|
||||
assert "observation.images.left" not in processed_obs
|
||||
assert "observation.images.right" not in processed_obs
|
||||
assert f"{OBS_IMAGES}.left" not in processed_obs
|
||||
assert f"{OBS_IMAGES}.right" not in processed_obs
|
||||
assert "observation.proprio" not in processed_obs
|
||||
|
||||
|
||||
@@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE}
|
||||
)
|
||||
pipeline = DataProcessorPipeline([processor1, processor2])
|
||||
|
||||
@@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory):
|
||||
}
|
||||
out = pipeline.transform_features(initial_features=spec)
|
||||
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.state"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
)
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.image"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
)
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"}
|
||||
assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
mapping = {OBS_STATE: "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
assert "observation.robot_state" in renamed and OBS_STATE not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
orig[OBS_STATE]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
Reference in New Issue
Block a user