forked from tangger/lerobot
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -9,7 +9,7 @@ from lerobot.processor.converters import (
|
||||
to_tensor,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
|
||||
|
||||
# Tests for the unified to_tensor function
|
||||
@@ -118,16 +118,16 @@ def test_to_tensor_dictionaries():
|
||||
|
||||
# Nested dictionary
|
||||
nested = {
|
||||
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
ACTION: {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
|
||||
OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
|
||||
}
|
||||
result = to_tensor(nested)
|
||||
assert isinstance(result, dict)
|
||||
assert isinstance(result["action"], dict)
|
||||
assert isinstance(result[ACTION], dict)
|
||||
assert isinstance(result[OBS_STR], dict)
|
||||
assert isinstance(result["action"]["mean"], torch.Tensor)
|
||||
assert isinstance(result[ACTION]["mean"], torch.Tensor)
|
||||
assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
|
||||
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result[ACTION]["mean"], torch.tensor([0.1, 0.2]))
|
||||
assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ def test_batch_to_transition_with_index_fields():
|
||||
# Create batch with index and task_index fields
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
ACTION: torch.randn(1, 4),
|
||||
"next.reward": 1.5,
|
||||
"next.done": False,
|
||||
"task": ["pick_cube"],
|
||||
@@ -262,7 +262,7 @@ def test_batch_to_transition_without_index_fields():
|
||||
# Batch without index/task_index
|
||||
batch = {
|
||||
OBS_STATE: torch.randn(1, 7),
|
||||
"action": torch.randn(1, 4),
|
||||
ACTION: torch.randn(1, 4),
|
||||
"task": ["pick_cube"],
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user