forked from tangger/lerobot
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import (
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
@@ -136,21 +137,21 @@ def test_get_feature_stats_single_value():
|
||||
|
||||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
"observation.image": [f"image_{i}.jpg" for i in range(100)],
|
||||
"observation.state": np.random.rand(100, 10),
|
||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||
OBS_STATE: np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
"observation.image": {"dtype": "image"},
|
||||
"observation.state": {"dtype": "numeric"},
|
||||
OBS_IMAGE: {"dtype": "image"},
|
||||
OBS_STATE: {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert "observation.image" in stats and "observation.state" in stats
|
||||
assert stats["observation.image"]["count"].item() == 100
|
||||
assert stats["observation.state"]["count"].item() == 100
|
||||
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
|
||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
@@ -224,38 +225,38 @@ def test_aggregate_feature_stats():
|
||||
def test_aggregate_stats():
|
||||
all_stats = [
|
||||
{
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [1, 2, 3],
|
||||
"max": [10, 20, 30],
|
||||
"mean": [5.5, 10.5, 15.5],
|
||||
"std": [2.87, 5.87, 8.87],
|
||||
"count": 10,
|
||||
},
|
||||
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
||||
OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
||||
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
|
||||
},
|
||||
{
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [2, 1, 0],
|
||||
"max": [15, 10, 5],
|
||||
"mean": [8.5, 5.5, 2.5],
|
||||
"std": [3.42, 2.42, 1.42],
|
||||
"count": 15,
|
||||
},
|
||||
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
||||
OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
||||
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
|
||||
},
|
||||
]
|
||||
|
||||
expected_agg_stats = {
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [1, 1, 0],
|
||||
"max": [15, 20, 30],
|
||||
"mean": [7.3, 7.5, 7.7],
|
||||
"std": [3.5317, 4.8267, 8.5581],
|
||||
"count": 25,
|
||||
},
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"min": 1,
|
||||
"max": 15,
|
||||
"mean": 7.3,
|
||||
@@ -283,7 +284,7 @@ def test_aggregate_stats():
|
||||
for fkey, stats in ep_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
if fkey == OBS_IMAGE and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
@@ -292,7 +293,7 @@ def test_aggregate_stats():
|
||||
for fkey, stats in expected_agg_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
if fkey == OBS_IMAGE and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
@@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup():
|
||||
def test_non_vector_last_wins_for_images():
|
||||
# Non-vector (images) with same name should be overwritten by the last image specified
|
||||
g1 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 480, 640),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 720, 1280),
|
||||
"names": ["channels", "height", "width"],
|
||||
@@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images():
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280)
|
||||
assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
|
||||
@@ -46,6 +46,7 @@ from lerobot.datasets.utils import (
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
@@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
# Instantiate both ways
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(
|
||||
@@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
("frame_index", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
("observation.state", 1, True),
|
||||
(OBS_STATE, 1, True),
|
||||
("next.reward", 0, False),
|
||||
("next.done", 0, False),
|
||||
]
|
||||
@@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata."""
|
||||
features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
|
||||
@@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_update_chunk_settings_video_dataset(tmp_path):
|
||||
"""Test update_chunk_settings with a video dataset to ensure video-specific logic works."""
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
f"{OBS_IMAGES}.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
|
||||
Reference in New Issue
Block a user