chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -9,6 +9,7 @@ from lerobot.processor.converters import (
to_tensor,
transition_to_batch,
)
from lerobot.utils.constants import OBS_STATE, OBS_STR
# Tests for the unified to_tensor function
@@ -118,16 +119,16 @@ def test_to_tensor_dictionaries():
# Nested dictionary
nested = {
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
"observation": {"mean": np.array([0.5, 0.6]), "count": 10},
OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
}
result = to_tensor(nested)
assert isinstance(result, dict)
assert isinstance(result["action"], dict)
assert isinstance(result["observation"], dict)
assert isinstance(result[OBS_STR], dict)
assert isinstance(result["action"]["mean"], torch.Tensor)
assert isinstance(result["observation"]["mean"], torch.Tensor)
assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))
assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
def test_to_tensor_none_filtering():
@@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields():
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
OBS_STATE: torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
@@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields():
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
observation={OBS_STATE: torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
@@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields():
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
OBS_STATE: torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
@@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields():
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
observation={OBS_STATE: torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)