forked from tangger/lerobot
chore: replace hard-coded next values with constants throughout all the source code (#2056)
This commit is contained in:
@@ -35,7 +35,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
@@ -258,9 +258,9 @@ def test_step_through_with_dict():
|
||||
batch = {
|
||||
OBS_IMAGE: None,
|
||||
ACTION: None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
REWARD: 0.0,
|
||||
DONE: False,
|
||||
TRUNCATED: False,
|
||||
"info": {},
|
||||
}
|
||||
|
||||
@@ -1843,9 +1843,9 @@ def test_save_load_with_custom_converter_functions():
|
||||
batch = {
|
||||
OBS_IMAGE: torch.randn(1, 3, 32, 32),
|
||||
ACTION: torch.randn(1, 7),
|
||||
"next.reward": torch.tensor([1.0]),
|
||||
"next.done": torch.tensor([False]),
|
||||
"next.truncated": torch.tensor([False]),
|
||||
REWARD: torch.tensor([1.0]),
|
||||
DONE: torch.tensor([False]),
|
||||
TRUNCATED: torch.tensor([False]),
|
||||
"info": {},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user