chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -21,7 +21,7 @@ from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.constants import ACTION, OBS_IMAGES
def test_default_parameters():
@@ -59,14 +59,14 @@ def test_calculate_episode_data_index():
def test_merge_simple_vectors():
g1 = {
"action": {
ACTION: {
"dtype": "float32",
"shape": (2,),
"names": ["ee.x", "ee.y"],
}
}
g2 = {
"action": {
ACTION: {
"dtype": "float32",
"shape": (2,),
"names": ["ee.y", "ee.z"],
@@ -75,23 +75,23 @@ def test_merge_simple_vectors():
out = combine_feature_dicts(g1, g2)
assert "action" in out
assert out["action"]["dtype"] == "float32"
assert ACTION in out
assert out[ACTION]["dtype"] == "float32"
# Names merged with preserved order and de-dupuplication
assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
assert out[ACTION]["names"] == ["ee.x", "ee.y", "ee.z"]
# Shape correctly recomputed from names length
assert out["action"]["shape"] == (3,)
assert out[ACTION]["shape"] == (3,)
def test_merge_multiple_groups_order_and_dedup():
g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
g1 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
g2 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
g3 = {ACTION: {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
out = combine_feature_dicts(g1, g2, g3)
assert out["action"]["names"] == ["a", "b", "c", "d"]
assert out["action"]["shape"] == (4,)
assert out[ACTION]["names"] == ["a", "b", "c", "d"]
assert out[ACTION]["shape"] == (4,)
def test_non_vector_last_wins_for_images():
@@ -117,8 +117,8 @@ def test_non_vector_last_wins_for_images():
def test_dtype_mismatch_raises():
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
g1 = {ACTION: {"dtype": "float32", "shape": (1,), "names": ["a"]}}
g2 = {ACTION: {"dtype": "float64", "shape": (1,), "names": ["b"]}}
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
_ = combine_feature_dicts(g1, g2)