chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import (
resize_robot_observation_image,
)
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
# ---------------------------------------------------------------------
# FPSTracker
@@ -115,7 +116,7 @@ def test_timed_action_getters():
def test_timed_observation_getters():
"""TimedObservation stores & returns timestamp, dict and timestep."""
ts = time.time()
obs_dict = {"observation.state": torch.ones(6)}
obs_dict = {OBS_STATE: torch.ones(6)}
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
@@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters():
# ------------------------------------------------------------------
# TimedObservation
# ------------------------------------------------------------------
obs_dict = {"observation.state": torch.arange(4).float()}
obs_dict = {OBS_STATE: torch.arange(4).float()}
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
to_bytes = pickle.dumps(to_in) # nosec
@@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters():
assert to_out.get_timestep() == 7
assert to_out.must_go is True
assert to_out.get_observation().keys() == obs_dict.keys()
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
# ---------------------------------------------------------------------
@@ -187,7 +188,7 @@ def test_observations_similar_true():
"""Distance below atol → observations considered similar."""
# Create mock lerobot features for the similarity check
lerobot_features = {
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
@@ -222,17 +223,17 @@ def _create_mock_robot_observation():
def _create_mock_lerobot_features():
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
return {
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
f"{OBS_IMAGES}.laptop": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
},
"observation.images.phone": {
f"{OBS_IMAGES}.phone": {
"dtype": "image",
"shape": [480, 640, 3],
"names": ["height", "width", "channels"],
@@ -243,11 +244,11 @@ def _create_mock_lerobot_features():
def _create_mock_policy_image_features():
"""Create mock policy image features with different resolutions."""
return {
"observation.images.laptop": PolicyFeature(
f"{OBS_IMAGES}.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Policy expects smaller resolution
),
"observation.images.phone": PolicyFeature(
f"{OBS_IMAGES}.phone": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 160, 160), # Different resolution for second camera
),
@@ -306,21 +307,21 @@ def test_prepare_raw_observation():
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
# Check that state is properly extracted and batched
assert "observation.state" in prepared
state = prepared["observation.state"]
assert OBS_STATE in prepared
state = prepared[OBS_STATE]
assert isinstance(state, torch.Tensor)
assert state.shape == (1, 4) # Batched state
# Check that images are processed and resized
assert "observation.images.laptop" in prepared
assert "observation.images.phone" in prepared
assert f"{OBS_IMAGES}.laptop" in prepared
assert f"{OBS_IMAGES}.phone" in prepared
laptop_img = prepared["observation.images.laptop"]
phone_img = prepared["observation.images.phone"]
laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
phone_img = prepared[f"{OBS_IMAGES}.phone"]
# Check image shapes match policy requirements
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
# Check that images are tensors
assert isinstance(laptop_img, torch.Tensor)
@@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic():
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all expected keys are present
assert "observation.state" in observation
assert "observation.images.laptop" in observation
assert "observation.images.phone" in observation
assert OBS_STATE in observation
assert f"{OBS_IMAGES}.laptop" in observation
assert f"{OBS_IMAGES}.phone" in observation
# Check state processing
state = observation["observation.state"]
state = observation[OBS_STATE]
assert isinstance(state, torch.Tensor)
assert state.device.type == device
assert state.shape == (1, 4) # Batched
# Check image processing
laptop_img = observation["observation.images.laptop"]
phone_img = observation["observation.images.phone"]
laptop_img = observation[f"{OBS_IMAGES}.laptop"]
phone_img = observation[f"{OBS_IMAGES}.phone"]
# Images should have batch dimension: (B, C, H, W)
assert laptop_img.shape == (1, 3, 224, 224)
@@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content():
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
lerobot_features = {
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"],
},
"observation.images.laptop": {
f"{OBS_IMAGES}.laptop": {
"dtype": "image",
"shape": [100, 100, 3],
"names": ["height", "width", "channels"],
},
}
policy_image_features = {
"observation.images.laptop": PolicyFeature(
f"{OBS_IMAGES}.laptop": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 50, 50), # Downsamples from 100x100
)
@@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content():
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
# Check that the center region has higher values than corners
# Due to bilinear interpolation, exact values will change but pattern should remain

View File

@@ -23,6 +23,7 @@ import pytest
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.utils.constants import OBS_STATE
from tests.utils import require_package
# -----------------------------------------------------------------------------
@@ -44,7 +45,7 @@ class MockPolicy:
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return a chunk of 20 dummy actions."""
batch_size = len(observation["observation.state"])
batch_size = len(observation[OBS_STATE])
return torch.zeros(batch_size, 20, 6)
def __init__(self):
@@ -77,7 +78,7 @@ def policy_server():
# Add mock lerobot_features that the observation similarity functions need
server.lerobot_features = {
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": [6],
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],