chore: replace hard-coded next values with constants throughout all the source code (#2056)
This commit is contained in:
@@ -22,7 +22,7 @@ import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
@@ -380,9 +380,9 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == ACTION:
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
elif feature == REWARD:
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == "next.done":
|
||||
elif feature == DONE:
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == OBS_IMAGE:
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
|
||||
Reference in New Issue
Block a user