forked from tangger/lerobot
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code
* chore(tests): replace hard-coded 'action' values with constants throughout all the test code
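The pattern applied in every hunk below is the same: dictionary keys and string arguments that used the literal "action" now use the shared ACTION constant imported from lerobot.utils.constants. A minimal sketch of the idea, assuming ACTION resolves to the string "action" (the constant's definition is not part of this diff):

    from lerobot.utils.constants import ACTION  # assumed to be defined as "action"

    # Before: the feature key is a hard-coded string literal.
    features = {"action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}}

    # After: the shared constant keys the same entry, so the name lives in one place.
    features = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}}

    assert ACTION in features  # equivalent to `assert "action" in features` if ACTION == "action"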
@@ -21,7 +21,7 @@ from huggingface_hub import DatasetCard
 from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
-from lerobot.utils.constants import OBS_IMAGES
+from lerobot.utils.constants import ACTION, OBS_IMAGES
 
 
 def test_default_parameters():
@@ -59,14 +59,14 @@ def test_calculate_episode_data_index():
 
 
 def test_merge_simple_vectors():
     g1 = {
-        "action": {
+        ACTION: {
             "dtype": "float32",
             "shape": (2,),
             "names": ["ee.x", "ee.y"],
         }
     }
     g2 = {
-        "action": {
+        ACTION: {
             "dtype": "float32",
             "shape": (2,),
             "names": ["ee.y", "ee.z"],
@@ -75,23 +75,23 @@ def test_merge_simple_vectors():
 
     out = combine_feature_dicts(g1, g2)
 
-    assert "action" in out
-    assert out["action"]["dtype"] == "float32"
+    assert ACTION in out
+    assert out[ACTION]["dtype"] == "float32"
     # Names merged with preserved order and de-duplication
-    assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
+    assert out[ACTION]["names"] == ["ee.x", "ee.y", "ee.z"]
     # Shape correctly recomputed from names length
-    assert out["action"]["shape"] == (3,)
+    assert out[ACTION]["shape"] == (3,)
 
 
 def test_merge_multiple_groups_order_and_dedup():
-    g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
-    g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
-    g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
+    g1 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
+    g2 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
+    g3 = {ACTION: {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
 
     out = combine_feature_dicts(g1, g2, g3)
 
-    assert out["action"]["names"] == ["a", "b", "c", "d"]
-    assert out["action"]["shape"] == (4,)
+    assert out[ACTION]["names"] == ["a", "b", "c", "d"]
+    assert out[ACTION]["shape"] == (4,)
 
 
 def test_non_vector_last_wins_for_images():
@@ -117,8 +117,8 @@ def test_non_vector_last_wins_for_images():
 
 
 def test_dtype_mismatch_raises():
-    g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
-    g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
+    g1 = {ACTION: {"dtype": "float32", "shape": (1,), "names": ["a"]}}
+    g2 = {ACTION: {"dtype": "float64", "shape": (1,), "names": ["b"]}}
 
     with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
         _ = combine_feature_dicts(g1, g2)
@@ -46,7 +46,7 @@ from lerobot.datasets.utils import (
 from lerobot.envs.factory import make_env_config
 from lerobot.policies.factory import make_policy_config
 from lerobot.robots import make_robot_from_config
-from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.mocks.mock_robot import MockRobotConfig
 from tests.utils import require_x86_64_kernel
@@ -75,7 +75,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     # Instantiate both ways
     robot = make_robot_from_config(MockRobotConfig())
-    action_features = hw_to_dataset_features(robot.action_features, "action", True)
+    action_features = hw_to_dataset_features(robot.action_features, ACTION, True)
     obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
     dataset_features = {**action_features, **obs_features}
     root_create = tmp_path / "create"
@@ -393,7 +393,7 @@ def test_factory(env_name, repo_id, policy_name):
     item = dataset[0]
 
     keys_ndim_required = [
-        ("action", 1, True),
+        (ACTION, 1, True),
         ("episode_index", 0, True),
         ("frame_index", 0, True),
         ("timestamp", 0, True),
@@ -668,7 +668,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
             "shape": (6,),
             "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
         },
-        "action": {
+        ACTION: {
             "dtype": "float32",
             "shape": (6,),
             "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
@@ -775,7 +775,7 @@ def test_update_chunk_settings_video_dataset(tmp_path):
             "shape": (480, 640, 3),
             "names": ["height", "width", "channels"],
         },
-        "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
+        ACTION: {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
     }
 
     # Create video dataset
@@ -842,7 +842,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
     """Test episode metadata consistency across multiple episodes."""
     features = {
         "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
-        "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
+        ACTION: {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
     }
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
 
@@ -852,7 +852,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
 
     for episode_idx in range(num_episodes):
         for _ in range(frames_per_episode[episode_idx]):
-            dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]})
+            dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]})
         dataset.save_episode()
 
     # Load and validate episode metadata
@@ -927,7 +927,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
     """Test that statistics are properly computed and stored for all features."""
     features = {
         "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
-        "action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
+        ACTION: {"dtype": "float32", "shape": (1,), "names": ["force"]},
     }
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
 
@@ -941,7 +941,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
        for frame_idx in range(frames_per_episode[episode_idx]):
            state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
            action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
-           dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
+           dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"})
        dataset.save_episode()
 
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
@@ -19,6 +19,7 @@ import torch
 
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.utils import safe_shard
+from lerobot.utils.constants import ACTION
 from tests.fixtures.constants import DUMMY_REPO_ID
 
 
@@ -234,7 +235,7 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_
     delta_timestamps = {
         camera_key: state_deltas,
         "state": state_deltas,
-        "action": action_deltas,
+        ACTION: action_deltas,
     }
 
     ds = lerobot_dataset_factory(
@@ -319,7 +320,7 @@ def test_frames_with_delta_consistency_with_shards(
     delta_timestamps = {
         camera_key: state_deltas,
         "state": state_deltas,
-        "action": action_deltas,
+        ACTION: action_deltas,
     }
 
     ds = lerobot_dataset_factory(