forked from tangger/lerobot
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -35,6 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -255,7 +256,7 @@ def test_step_through_with_dict():
|
||||
pipeline = DataProcessorPipeline([step1, step2])
|
||||
|
||||
batch = {
|
||||
"observation.image": None,
|
||||
OBS_IMAGE: None,
|
||||
"action": None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
@@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
|
||||
# Verify it uses default converters by checking with standard batch format
|
||||
batch = {
|
||||
"observation.image": torch.randn(1, 3, 32, 32),
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
"action": torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
@@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions():
|
||||
# Should work with standard format (wouldn't work with custom converter)
|
||||
result = loaded(batch)
|
||||
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
|
||||
assert "observation.image" in result
|
||||
assert OBS_IMAGE in result
|
||||
|
||||
|
||||
class NonCompliantStep:
|
||||
@@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep):
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# State features (mix EE and a joint state)
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float
|
||||
if self.add_front_image:
|
||||
features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape
|
||||
return features
|
||||
|
||||
|
||||
@@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only():
|
||||
)
|
||||
|
||||
# Expect only "action" with joint names
|
||||
assert "action" in out and "observation.state" not in out
|
||||
assert "action" in out and OBS_STATE not in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
|
||||
assert out["action"]["shape"] == (len(out["action"]["names"]),)
|
||||
@@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "observation.state"],
|
||||
patterns=["action.ee", OBS_STATE],
|
||||
)
|
||||
|
||||
# Action should pack only EE names
|
||||
@@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos():
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
|
||||
# Observation state should pack both ee.x and j1.pos as a vector
|
||||
assert "observation.state" in out
|
||||
assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert out["observation.state"]["dtype"] == "float32"
|
||||
assert OBS_STATE in out
|
||||
assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
|
||||
assert out[OBS_STATE]["dtype"] == "float32"
|
||||
|
||||
# Cameras from initial_features appear as videos
|
||||
for cam in ("front", "side"):
|
||||
key = f"observation.images.{cam}"
|
||||
key = f"{OBS_IMAGES}.{cam}"
|
||||
assert key in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
assert out[key]["shape"] == initial[cam]
|
||||
@@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false():
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = "observation.images.back"
|
||||
key_front = "observation.images.front"
|
||||
key = f"{OBS_IMAGES}.back"
|
||||
key_front = f"{OBS_IMAGES}.front"
|
||||
assert key not in out
|
||||
assert key_front not in out
|
||||
|
||||
@@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true():
|
||||
patterns=None,
|
||||
)
|
||||
|
||||
key = "observation.images.front"
|
||||
key_back = "observation.images.back"
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
key_back = f"{OBS_IMAGES}.back"
|
||||
assert key in out
|
||||
assert key_back in out
|
||||
assert out[key]["dtype"] == "video"
|
||||
@@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image():
|
||||
pipeline=rp,
|
||||
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
|
||||
use_videos=True,
|
||||
patterns=["observation.images.front"],
|
||||
patterns=[f"{OBS_IMAGES}.front"],
|
||||
)
|
||||
|
||||
key = "observation.images.front"
|
||||
key = f"{OBS_IMAGES}.front"
|
||||
assert key in out
|
||||
assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial
|
||||
|
||||
Reference in New Issue
Block a user