From c5b5955c5acf15e1b3f0ace6ae72612f98c1fe06 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Fri, 26 Sep 2025 14:30:07 +0200
Subject: [PATCH] chore: replace hard-coded next values with constants
 throughout all the source code (#2056)

---
 src/lerobot/datasets/factory.py            |  4 +-
 src/lerobot/processor/converters.py        | 14 ++--
 src/lerobot/rl/buffer.py                   | 16 ++---
 src/lerobot/rl/crop_dataset_roi.py         |  3 +-
 src/lerobot/rl/gym_manipulator.py          | 10 +--
 src/lerobot/scripts/lerobot_dataset_viz.py | 10 +--
 src/lerobot/scripts/lerobot_eval.py        |  6 +-
 tests/datasets/test_datasets.py            |  6 +-
 .../hilserl/test_modeling_classifier.py    | 10 +--
 tests/processor/test_batch_conversion.py   | 68 +++++++++----------
 tests/processor/test_converters.py         |  6 +-
 tests/processor/test_pipeline.py           | 14 ++--
 tests/utils/test_replay_buffer.py          |  6 +-
 13 files changed, 87 insertions(+), 86 deletions(-)

diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py
index f74b6ac4..f3ceb2b0 100644
--- a/src/lerobot/datasets/factory.py
+++ b/src/lerobot/datasets/factory.py
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
 )
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
-from lerobot.utils.constants import ACTION, OBS_PREFIX
+from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
 
 IMAGENET_STATS = {
     "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
@@ -55,7 +55,7 @@ def resolve_delta_timestamps(
     """
     delta_timestamps = {}
     for key in ds_meta.features:
-        if key == "next.reward" and cfg.reward_delta_indices is not None:
+        if key == REWARD and cfg.reward_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
         if key == ACTION and cfg.action_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py
index 68f9dd6f..6b0b6759 100644
--- a/src/lerobot/processor/converters.py
+++ b/src/lerobot/processor/converters.py
@@ -23,7 +23,7 @@ from typing import Any
 import numpy as np
 import torch
 
-from lerobot.utils.constants import ACTION, OBS_PREFIX
+from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
 
 from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
 
@@ -355,9 +355,9 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
     return create_transition(
         observation=observation_keys if observation_keys else None,
         action=batch.get(ACTION),
-        reward=batch.get("next.reward", 0.0),
-        done=batch.get("next.done", False),
-        truncated=batch.get("next.truncated", False),
+        reward=batch.get(REWARD, 0.0),
+        done=batch.get(DONE, False),
+        truncated=batch.get(TRUNCATED, False),
         info=batch.get("info", {}),
         complementary_data=complementary_data if complementary_data else None,
     )
@@ -380,9 +380,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
 
     batch = {
         ACTION: transition.get(TransitionKey.ACTION),
-        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
-        "next.done": transition.get(TransitionKey.DONE, False),
-        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
+        REWARD: transition.get(TransitionKey.REWARD, 0.0),
+        DONE: transition.get(TransitionKey.DONE, False),
+        TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
         "info": transition.get(TransitionKey.INFO, {}),
     }
 
diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py
index b572bbce..d30b6508 100644
--- a/src/lerobot/rl/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.utils.constants import ACTION, OBS_IMAGE
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
 from lerobot.utils.transition import Transition
 
 
@@ -534,8 +534,8 @@ class ReplayBuffer:
         features[ACTION] = act_info
 
         # Add "reward" and "done"
-        features["next.reward"] = {"dtype": "float32", "shape": (1,)}
-        features["next.done"] = {"dtype": "bool", "shape": (1,)}
+        features[REWARD] = {"dtype": "float32", "shape": (1,)}
+        features[DONE] = {"dtype": "bool", "shape": (1,)}
 
         # Add state keys
         for key in self.states:
@@ -578,8 +578,8 @@ class ReplayBuffer:
 
             # Fill action, reward, done
             frame_dict[ACTION] = self.actions[actual_idx].cpu()
-            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
-            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
+            frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
+            frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
             frame_dict["task"] = task_name
 
             # Add complementary_info if available
@@ -648,7 +648,7 @@ class ReplayBuffer:
 
         # Check if the dataset has "next.done" key
        sample = dataset[0]
-        has_done_key = "next.done" in sample
+        has_done_key = DONE in sample
 
         # Check for complementary_info keys
         complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
@@ -671,11 +671,11 @@ class ReplayBuffer:
             action = current_sample[ACTION].unsqueeze(0)  # Add batch dimension
 
             # ----- 3) Reward and done -----
-            reward = float(current_sample["next.reward"].item())  # ensure float
+            reward = float(current_sample[REWARD].item())  # ensure float
 
             # Determine done flag - use next.done if available, otherwise infer from episode boundaries
             if has_done_key:
-                done = bool(current_sample["next.done"].item())  # ensure bool
+                done = bool(current_sample[DONE].item())  # ensure bool
             else:
                 # If this is the last frame or if next frame is in a different episode, mark as done
                 done = False
diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py
index c4318c41..281069e1 100644
--- a/src/lerobot/rl/crop_dataset_roi.py
+++ b/src/lerobot/rl/crop_dataset_roi.py
@@ -25,6 +25,7 @@ import torchvision.transforms.functional as F  # type: ignore # noqa: N812
 from tqdm import tqdm  # type: ignore
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import DONE, REWARD
 
 
 def select_rect_roi(img):
@@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
         for key, value in frame.items():
             if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
                 continue
-            if key in ("next.done", "next.reward"):
+            if key in (DONE, REWARD):
                 # if not isinstance(value, str) and len(value.shape) == 0:
                 value = value.unsqueeze(0)
 
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
index fa9f4e3e..ad36f1b3 100644
--- a/src/lerobot/rl/gym_manipulator.py
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
 )
 from lerobot.teleoperators.teleoperator import Teleoperator
 from lerobot.teleoperators.utils import TeleopEvents
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
 from lerobot.utils.robot_utils import busy_wait
 from lerobot.utils.utils import log_say
 
@@ -602,8 +602,8 @@ def control_loop(
     action_features = teleop_device.action_features
     features = {
         ACTION: action_features,
-        "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
-        "next.done": {"dtype": "bool", "shape": (1,), "names": None},
+        REWARD: {"dtype": "float32", "shape": (1,), "names": None},
+        DONE: {"dtype": "bool", "shape": (1,), "names": None},
     }
     if use_gripper:
         features["complementary_info.discrete_penalty"] = {
@@ -673,8 +673,8 @@ def control_loop(
             frame = {
                 **observations,
                 ACTION: action_to_record.cpu(),
-                "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
-                "next.done": np.array([terminated or truncated], dtype=bool),
+                REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
+                DONE: np.array([terminated or truncated], dtype=bool),
             }
             if use_gripper:
                 discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py
index adff5c08..55708d9a 100644
--- a/src/lerobot/scripts/lerobot_dataset_viz.py
+++ b/src/lerobot/scripts/lerobot_dataset_viz.py
@@ -75,7 +75,7 @@ import torch.utils.data
 import tqdm
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.utils.constants import ACTION, OBS_STATE
+from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
 
 
 class EpisodeSampler(torch.utils.data.Sampler):
@@ -166,11 +166,11 @@ def visualize_dataset(
             for dim_idx, val in enumerate(batch[OBS_STATE][i]):
                 rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
 
-            if "next.done" in batch:
-                rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
+            if DONE in batch:
+                rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
 
-            if "next.reward" in batch:
-                rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
+            if REWARD in batch:
+                rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
 
             if "next.success" in batch:
                 rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index 882aeacc..d45be5c4 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
 from lerobot.policies.factory import make_policy, make_pre_post_processors
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.processor import PolicyAction, PolicyProcessorPipeline
-from lerobot.utils.constants import ACTION, OBS_STR
+from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
 from lerobot.utils.io_utils import write_video
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.utils import (
@@ -451,9 +451,9 @@ def _compile_episode_data(
             "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
             "frame_index": torch.arange(0, num_frames - 1, 1),
             "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
-            "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
+            DONE: rollout_data["done"][ep_ix, : num_frames - 1],
             "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
-            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
+            REWARD: rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
         }
 
         # For the last observation frame, all other keys will just be copy padded.
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index fcfef677..b9e966fe 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -46,7 +46,7 @@ from lerobot.datasets.utils import (
 from lerobot.envs.factory import make_env_config
 from lerobot.policies.factory import make_policy_config
 from lerobot.robots import make_robot_from_config
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.mocks.mock_robot import MockRobotConfig
 from tests.utils import require_x86_64_kernel
@@ -399,8 +399,8 @@ def test_factory(env_name, repo_id, policy_name):
         ("timestamp", 0, True),
         # TODO(rcadene): should we rename it agent_pos?
         (OBS_STATE, 1, True),
-        ("next.reward", 0, False),
-        ("next.done", 0, False),
+        (REWARD, 0, False),
+        (DONE, 0, False),
     ]
 
     # test number of dimensions
diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py
index 7a878223..a572ea9e 100644
--- a/tests/policies/hilserl/test_modeling_classifier.py
+++ b/tests/policies/hilserl/test_modeling_classifier.py
@@ -19,7 +19,7 @@ import torch
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
 from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
-from lerobot.utils.constants import OBS_IMAGE
+from lerobot.utils.constants import OBS_IMAGE, REWARD
 from tests.utils import require_package
 
 
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
         OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
     config.output_features = {
-        "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
+        REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
     }
     config.normalization_mapping = {
         "VISUAL": NormalizationMode.IDENTITY,
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
 
     input = {
         OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
-        "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
+        REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
     }
 
     images, labels = classifier.extract_images_and_labels(input)
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
         OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
     config.output_features = {
-        "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
+        REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
     }
     config.num_cameras = 1
     config.num_classes = num_classes
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
 
     input = {
         OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
-        "next.reward": torch.rand((batch_size, num_classes)),
+        REWARD: torch.rand((batch_size, num_classes)),
     }
 
     images, labels = classifier.extract_images_and_labels(input)
diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py
index 0f701897..88b87312 100644
--- a/tests/processor/test_batch_conversion.py
+++ b/tests/processor/test_batch_conversion.py
@@ -2,7 +2,7 @@ import torch
 
 from lerobot.processor import DataProcessorPipeline, TransitionKey
 from lerobot.processor.converters import batch_to_transition, transition_to_batch
-from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED
 
 
 def _dummy_batch():
@@ -12,9 +12,9 @@ def _dummy_batch():
         f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
         OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
         ACTION: torch.tensor([[0.5]]),
-        "next.reward": 1.0,
-        "next.done": False,
-        "next.truncated": False,
+        REWARD: 1.0,
+        DONE: False,
+        TRUNCATED: False,
         "info": {"key": "value"},
     }
 
@@ -38,9 +38,9 @@ def test_observation_grouping_roundtrip():
 
     # Check other fields
     assert torch.allclose(batch_out[ACTION], batch_in[ACTION])
-    assert batch_out["next.reward"] == batch_in["next.reward"]
-    assert batch_out["next.done"] == batch_in["next.done"]
-    assert batch_out["next.truncated"] == batch_in["next.truncated"]
+    assert batch_out[REWARD] == batch_in[REWARD]
+    assert batch_out[DONE] == batch_in[DONE]
+    assert batch_out[TRUNCATED] == batch_in[TRUNCATED]
     assert batch_out["info"] == batch_in["info"]
 
 
@@ -51,9 +51,9 @@ def test_batch_to_transition_observation_grouping():
         f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
         OBS_STATE: [1, 2, 3, 4],
         ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]),
-        "next.reward": 1.5,
-        "next.done": True,
-        "next.truncated": False,
+        REWARD: 1.5,
+        DONE: True,
+        TRUNCATED: False,
         "info": {"episode": 42},
     }
 
@@ -115,9 +115,9 @@ def test_transition_to_batch_observation_flattening():
 
     # Check other fields are mapped to next.* format
     assert batch[ACTION] == "action_data"
-    assert batch["next.reward"] == 1.5
-    assert batch["next.done"]
-    assert not batch["next.truncated"]
+    assert batch[REWARD] == 1.5
+    assert batch[DONE]
+    assert not batch[TRUNCATED]
     assert batch["info"] == {"episode": 42}
 
 
@@ -125,9 +125,9 @@ def test_no_observation_keys():
     """Test behavior when there are no observation.* keys."""
     batch = {
         ACTION: torch.tensor([1.0, 2.0]),
-        "next.reward": 2.0,
-        "next.done": False,
-        "next.truncated": True,
+        REWARD: 2.0,
+        DONE: False,
+        TRUNCATED: True,
         "info": {"test": "no_obs"},
     }
 
@@ -146,9 +146,9 @@ def test_no_observation_keys():
     # Round trip should work
     reconstructed_batch = transition_to_batch(transition)
     assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0]))
-    assert reconstructed_batch["next.reward"] == 2.0
-    assert not reconstructed_batch["next.done"]
-    assert reconstructed_batch["next.truncated"]
+    assert reconstructed_batch[REWARD] == 2.0
+    assert not reconstructed_batch[DONE]
+    assert reconstructed_batch[TRUNCATED]
     assert reconstructed_batch["info"] == {"test": "no_obs"}
 
 
@@ -173,9 +173,9 @@ def test_minimal_batch():
     reconstructed_batch = transition_to_batch(transition)
    assert reconstructed_batch[OBS_STATE] == "minimal_state"
     assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5]))
-    assert reconstructed_batch["next.reward"] == 0.0
-    assert not reconstructed_batch["next.done"]
-    assert not reconstructed_batch["next.truncated"]
+    assert reconstructed_batch[REWARD] == 0.0
+    assert not reconstructed_batch[DONE]
+    assert not reconstructed_batch[TRUNCATED]
     assert reconstructed_batch["info"] == {}
 
 
@@ -197,9 +197,9 @@ def test_empty_batch():
     # Round trip
     reconstructed_batch = transition_to_batch(transition)
     assert reconstructed_batch[ACTION] is None
-    assert reconstructed_batch["next.reward"] == 0.0
-    assert not reconstructed_batch["next.done"]
-    assert not reconstructed_batch["next.truncated"]
+    assert reconstructed_batch[REWARD] == 0.0
+    assert not reconstructed_batch[DONE]
+    assert not reconstructed_batch[TRUNCATED]
     assert reconstructed_batch["info"] == {}
 
 
@@ -210,9 +210,9 @@ def test_complex_nested_observation():
         f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
         OBS_STATE: torch.randn(7),
         ACTION: torch.randn(8),
-        "next.reward": 3.14,
-        "next.done": False,
-        "next.truncated": True,
+        REWARD: 3.14,
+        DONE: False,
+        TRUNCATED: True,
         "info": {"episode_length": 200, "success": True},
     }
 
@@ -240,9 +240,9 @@ def test_complex_nested_observation():
     assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION])
 
     # Check other fields
-    assert batch["next.reward"] == reconstructed_batch["next.reward"]
-    assert batch["next.done"] == reconstructed_batch["next.done"]
-    assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
+    assert batch[REWARD] == reconstructed_batch[REWARD]
+    assert batch[DONE] == reconstructed_batch[DONE]
+    assert batch[TRUNCATED] == reconstructed_batch[TRUNCATED]
     assert batch["info"] == reconstructed_batch["info"]
 
 
@@ -267,13 +267,13 @@ def test_custom_converter():
     batch = {
         OBS_STATE: torch.randn(1, 4),
         ACTION: torch.randn(1, 2),
-        "next.reward": 1.0,
-        "next.done": False,
+        REWARD: 1.0,
+        DONE: False,
     }
 
     result = processor(batch)
 
     # Check the reward was doubled by our custom converter
-    assert result["next.reward"] == 2.0
+    assert result[REWARD] == 2.0
     assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
     assert torch.allclose(result[ACTION], batch[ACTION])
diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py
index d347858d..bc58f7a6 100644
--- a/tests/processor/test_converters.py
+++ b/tests/processor/test_converters.py
@@ -9,7 +9,7 @@ from lerobot.processor.converters import (
     to_tensor,
     transition_to_batch,
 )
-from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
+from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD
 
 
 # Tests for the unified to_tensor function
@@ -201,8 +201,8 @@ def test_batch_to_transition_with_index_fields():
     batch = {
         OBS_STATE: torch.randn(1, 7),
         ACTION: torch.randn(1, 4),
-        "next.reward": 1.5,
-        "next.done": False,
+        REWARD: 1.5,
+        DONE: False,
         "task": ["pick_cube"],
         "index": torch.tensor([42], dtype=torch.int64),
         "task_index": torch.tensor([3], dtype=torch.int64),
diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py
index 6dbf3745..904fd6fc 100644
--- a/tests/processor/test_pipeline.py
+++ b/tests/processor/test_pipeline.py
@@ -35,7 +35,7 @@ from lerobot.processor import (
     TransitionKey,
 )
 from lerobot.processor.converters import create_transition, identity_transition
-from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
 from tests.conftest import assert_contract_is_typed
 
 
@@ -258,9 +258,9 @@ def test_step_through_with_dict():
     batch = {
         OBS_IMAGE: None,
         ACTION: None,
-        "next.reward": 0.0,
-        "next.done": False,
-        "next.truncated": False,
+        REWARD: 0.0,
+        DONE: False,
+        TRUNCATED: False,
         "info": {},
     }
 
@@ -1843,9 +1843,9 @@ def test_save_load_with_custom_converter_functions():
     batch = {
         OBS_IMAGE: torch.randn(1, 3, 32, 32),
         ACTION: torch.randn(1, 7),
-        "next.reward": torch.tensor([1.0]),
-        "next.done": torch.tensor([False]),
-        "next.truncated": torch.tensor([False]),
+        REWARD: torch.tensor([1.0]),
+        DONE: torch.tensor([False]),
+        TRUNCATED: torch.tensor([False]),
         "info": {},
     }
 
diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py
index 1e6c0df9..ddf0771f 100644
--- a/tests/utils/test_replay_buffer.py
+++ b/tests/utils/test_replay_buffer.py
@@ -22,7 +22,7 @@ import torch
 
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
-from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
 from tests.fixtures.constants import DUMMY_REPO_ID
 
 
@@ -380,9 +380,9 @@ def test_to_lerobot_dataset(tmp_path):
         for feature, value in ds[i].items():
            if feature == ACTION:
                 assert torch.equal(value, buffer.actions[i])
-            elif feature == "next.reward":
+            elif feature == REWARD:
                 assert torch.equal(value, buffer.rewards[i])
-            elif feature == "next.done":
+            elif feature == DONE:
                 assert torch.equal(value, buffer.dones[i])
             elif feature == OBS_IMAGE:
                 # Tensor -> numpy is not precise, so we have some diff there
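
Note (appendix, not part of the patch): the substitution assumes the imported constants carry the
same string values as the literals they replace, so dataset keys and stored data are unaffected. A
minimal, hypothetical sketch of the relevant definitions in lerobot/utils/constants.py, inferred
from the replaced strings rather than shown by the diff:

    # Hypothetical sketch of lerobot/utils/constants.py as assumed by this patch.
    # REWARD, DONE and TRUNCATED are inferred from the string literals being replaced;
    # the remaining values are assumptions and may differ in the real module.
    REWARD = "next.reward"
    DONE = "next.done"
    TRUNCATED = "next.truncated"
    ACTION = "action"
    OBS_STR = "observation"
    OBS_PREFIX = "observation."
    OBS_STATE = "observation.state"
    OBS_IMAGE = "observation.image"
    OBS_IMAGES = "observation.images"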