chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import (
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
@@ -115,7 +116,7 @@ def test_timed_action_getters():
|
||||
def test_timed_observation_getters():
|
||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
||||
ts = time.time()
|
||||
obs_dict = {"observation.state": torch.ones(6)}
|
||||
obs_dict = {OBS_STATE: torch.ones(6)}
|
||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
||||
|
||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
@@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters():
|
||||
# ------------------------------------------------------------------
|
||||
# TimedObservation
|
||||
# ------------------------------------------------------------------
|
||||
obs_dict = {"observation.state": torch.arange(4).float()}
|
||||
obs_dict = {OBS_STATE: torch.arange(4).float()}
|
||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
||||
|
||||
to_bytes = pickle.dumps(to_in) # nosec
|
||||
@@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters():
|
||||
assert to_out.get_timestep() == 7
|
||||
assert to_out.must_go is True
|
||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
||||
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
|
||||
torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
@@ -187,7 +188,7 @@ def test_observations_similar_true():
|
||||
"""Distance below atol → observations considered similar."""
|
||||
# Create mock lerobot features for the similarity check
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
@@ -222,17 +223,17 @@ def _create_mock_robot_observation():
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.phone": {
|
||||
f"{OBS_IMAGES}.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
@@ -243,11 +244,11 @@ def _create_mock_lerobot_features():
|
||||
def _create_mock_policy_image_features():
|
||||
"""Create mock policy image features with different resolutions."""
|
||||
return {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
||||
),
|
||||
"observation.images.phone": PolicyFeature(
|
||||
f"{OBS_IMAGES}.phone": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 160, 160), # Different resolution for second camera
|
||||
),
|
||||
@@ -306,21 +307,21 @@ def test_prepare_raw_observation():
|
||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that state is properly extracted and batched
|
||||
assert "observation.state" in prepared
|
||||
state = prepared["observation.state"]
|
||||
assert OBS_STATE in prepared
|
||||
state = prepared[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched state
|
||||
|
||||
# Check that images are processed and resized
|
||||
assert "observation.images.laptop" in prepared
|
||||
assert "observation.images.phone" in prepared
|
||||
assert f"{OBS_IMAGES}.laptop" in prepared
|
||||
assert f"{OBS_IMAGES}.phone" in prepared
|
||||
|
||||
laptop_img = prepared["observation.images.laptop"]
|
||||
phone_img = prepared["observation.images.phone"]
|
||||
laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = prepared[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Check image shapes match policy requirements
|
||||
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
|
||||
assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
|
||||
|
||||
# Check that images are tensors
|
||||
assert isinstance(laptop_img, torch.Tensor)
|
||||
@@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic():
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert "observation.state" in observation
|
||||
assert "observation.images.laptop" in observation
|
||||
assert "observation.images.phone" in observation
|
||||
assert OBS_STATE in observation
|
||||
assert f"{OBS_IMAGES}.laptop" in observation
|
||||
assert f"{OBS_IMAGES}.phone" in observation
|
||||
|
||||
# Check state processing
|
||||
state = observation["observation.state"]
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.device.type == device
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
laptop_img = observation["observation.images.laptop"]
|
||||
phone_img = observation["observation.images.phone"]
|
||||
laptop_img = observation[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = observation[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Images should have batch dimension: (B, C, H, W)
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
@@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content():
|
||||
|
||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [100, 100, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
policy_image_features = {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 50, 50), # Downsamples from 100x100
|
||||
)
|
||||
@@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content():
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
||||
|
||||
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
# Check that the center region has higher values than corners
|
||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -44,7 +45,7 @@ class MockPolicy:
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation["observation.state"])
|
||||
batch_size = len(observation[OBS_STATE])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
@@ -77,7 +78,7 @@ def policy_server():
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
|
||||
Reference in New Issue
Block a user