chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -9,6 +9,7 @@ from lerobot.processor.converters import (
|
||||
to_tensor,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
@@ -118,16 +119,16 @@ def test_to_tensor_dictionaries():
|
||||
# Nested dictionary
|
||||
nested = {
|
||||
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
"observation": {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
}
|
||||
result = to_tensor(nested)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["action"], dict)
|
||||
assert isinstance(result["observation"], dict)
|
||||
assert isinstance(result[OBS_STR], dict)
|
||||
assert isinstance(result["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(result["observation"]["mean"], torch.Tensor)
|
||||
assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
|
||||
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))
|
||||
assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
def test_to_tensor_none_filtering():
|
||||
@@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields():
|
||||
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
@@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields():
|
||||
|
||||
# Create transition with index and task_index in complementary_data
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
reward=1.5,
|
||||
done=False,
|
||||
@@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields():
|
||||
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 7),
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
@@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields():
|
||||
|
||||
# Transition without index/task_index
|
||||
transition = create_transition(
|
||||
observation={"observation.state": torch.randn(1, 7)},
|
||||
observation={OBS_STATE: torch.randn(1, 7)},
|
||||
action=torch.randn(1, 4),
|
||||
complementary_data={"task": ["navigate"]},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user