chore: replace hard-coded "next.*" key strings ("next.reward", "next.done", "next.truncated") with the REWARD, DONE and TRUNCATED constants throughout the source code (#2056)

This commit is contained in:
Steven Palma
2025-09-26 14:30:07 +02:00
committed by GitHub
parent ec40ccde0d
commit c5b5955c5a
13 changed files with 87 additions and 86 deletions

View File

@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
) )
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = { IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -55,7 +55,7 @@ def resolve_delta_timestamps(
""" """
delta_timestamps = {} delta_timestamps = {}
for key in ds_meta.features: for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None: if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None: if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]

View File

@@ -23,7 +23,7 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from lerobot.utils.constants import ACTION, OBS_PREFIX from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -355,9 +355,9 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
return create_transition( return create_transition(
observation=observation_keys if observation_keys else None, observation=observation_keys if observation_keys else None,
action=batch.get(ACTION), action=batch.get(ACTION),
reward=batch.get("next.reward", 0.0), reward=batch.get(REWARD, 0.0),
done=batch.get("next.done", False), done=batch.get(DONE, False),
truncated=batch.get("next.truncated", False), truncated=batch.get(TRUNCATED, False),
info=batch.get("info", {}), info=batch.get("info", {}),
complementary_data=complementary_data if complementary_data else None, complementary_data=complementary_data if complementary_data else None,
) )
@@ -380,9 +380,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
batch = { batch = {
ACTION: transition.get(TransitionKey.ACTION), ACTION: transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0), REWARD: transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False), DONE: transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False), TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}), "info": transition.get(TransitionKey.INFO, {}),
} }

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_IMAGE from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
@@ -534,8 +534,8 @@ class ReplayBuffer:
features[ACTION] = act_info features[ACTION] = act_info
# Add "reward" and "done" # Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)} features[REWARD] = {"dtype": "float32", "shape": (1,)}
features["next.done"] = {"dtype": "bool", "shape": (1,)} features[DONE] = {"dtype": "bool", "shape": (1,)}
# Add state keys # Add state keys
for key in self.states: for key in self.states:
@@ -578,8 +578,8 @@ class ReplayBuffer:
# Fill action, reward, done # Fill action, reward, done
frame_dict[ACTION] = self.actions[actual_idx].cpu() frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name frame_dict["task"] = task_name
# Add complementary_info if available # Add complementary_info if available
@@ -648,7 +648,7 @@ class ReplayBuffer:
# Check if the dataset has "next.done" key # Check if the dataset has "next.done" key
sample = dataset[0] sample = dataset[0]
has_done_key = "next.done" in sample has_done_key = DONE in sample
# Check for complementary_info keys # Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
@@ -671,11 +671,11 @@ class ReplayBuffer:
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done ----- # ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float reward = float(current_sample[REWARD].item()) # ensure float
# Determine done flag - use next.done if available, otherwise infer from episode boundaries # Determine done flag - use next.done if available, otherwise infer from episode boundaries
if has_done_key: if has_done_key:
done = bool(current_sample["next.done"].item()) # ensure bool done = bool(current_sample[DONE].item()) # ensure bool
else: else:
# If this is the last frame or if next frame is in a different episode, mark as done # If this is the last frame or if next frame is in a different episode, mark as done
done = False done = False

View File

@@ -25,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore from tqdm import tqdm # type: ignore
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import DONE, REWARD
def select_rect_roi(img): def select_rect_roi(img):
@@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
for key, value in frame.items(): for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"): if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
continue continue
if key in ("next.done", "next.reward"): if key in (DONE, REWARD):
# if not isinstance(value, str) and len(value.shape) == 0: # if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0) value = value.unsqueeze(0)

View File

@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
) )
from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.robot_utils import busy_wait from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -602,8 +602,8 @@ def control_loop(
action_features = teleop_device.action_features action_features = teleop_device.action_features
features = { features = {
ACTION: action_features, ACTION: action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None}, REWARD: {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None}, DONE: {"dtype": "bool", "shape": (1,), "names": None},
} }
if use_gripper: if use_gripper:
features["complementary_info.discrete_penalty"] = { features["complementary_info.discrete_penalty"] = {
@@ -673,8 +673,8 @@ def control_loop(
frame = { frame = {
**observations, **observations,
ACTION: action_to_record.cpu(), ACTION: action_to_record.cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool), DONE: np.array([terminated or truncated], dtype=bool),
} }
if use_gripper: if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)

View File

@@ -75,7 +75,7 @@ import torch.utils.data
import tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler): class EpisodeSampler(torch.utils.data.Sampler):
@@ -166,11 +166,11 @@ def visualize_dataset(
for dim_idx, val in enumerate(batch[OBS_STATE][i]): for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch: if DONE in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
if "next.reward" in batch: if REWARD in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
if "next.success" in batch: if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))

View File

@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
@@ -451,9 +451,9 @@ def _compile_episode_data(
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1), "frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps, "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1], DONE: rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), REWARD: rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
} }
# For the last observation frame, all other keys will just be copy padded. # For the last observation frame, all other keys will just be copy padded.

View File

@@ -46,7 +46,7 @@ from lerobot.datasets.utils import (
from lerobot.envs.factory import make_env_config from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config from lerobot.robots import make_robot_from_config
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel
@@ -399,8 +399,8 @@ def test_factory(env_name, repo_id, policy_name):
("timestamp", 0, True), ("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos? # TODO(rcadene): should we rename it agent_pos?
(OBS_STATE, 1, True), (OBS_STATE, 1, True),
("next.reward", 0, False), (REWARD, 0, False),
("next.done", 0, False), (DONE, 0, False),
] ]
# test number of dimensions # test number of dimensions

View File

@@ -19,7 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import require_package from tests.utils import require_package
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
} }
config.output_features = { config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
} }
config.normalization_mapping = { config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY, "VISUAL": NormalizationMode.IDENTITY,
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
input = { input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
} }
images, labels = classifier.extract_images_and_labels(input) images, labels = classifier.extract_images_and_labels(input)
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
} }
config.output_features = { config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
} }
config.num_cameras = 1 config.num_cameras = 1
config.num_classes = num_classes config.num_classes = num_classes
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
input = { input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)), REWARD: torch.rand((batch_size, num_classes)),
} }
images, labels = classifier.extract_images_and_labels(input) images, labels = classifier.extract_images_and_labels(input)

View File

@@ -2,7 +2,7 @@ import torch
from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor import DataProcessorPipeline, TransitionKey
from lerobot.processor.converters import batch_to_transition, transition_to_batch from lerobot.processor.converters import batch_to_transition, transition_to_batch
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED
def _dummy_batch(): def _dummy_batch():
@@ -12,9 +12,9 @@ def _dummy_batch():
f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
ACTION: torch.tensor([[0.5]]), ACTION: torch.tensor([[0.5]]),
"next.reward": 1.0, REWARD: 1.0,
"next.done": False, DONE: False,
"next.truncated": False, TRUNCATED: False,
"info": {"key": "value"}, "info": {"key": "value"},
} }
@@ -38,9 +38,9 @@ def test_observation_grouping_roundtrip():
# Check other fields # Check other fields
assert torch.allclose(batch_out[ACTION], batch_in[ACTION]) assert torch.allclose(batch_out[ACTION], batch_in[ACTION])
assert batch_out["next.reward"] == batch_in["next.reward"] assert batch_out[REWARD] == batch_in[REWARD]
assert batch_out["next.done"] == batch_in["next.done"] assert batch_out[DONE] == batch_in[DONE]
assert batch_out["next.truncated"] == batch_in["next.truncated"] assert batch_out[TRUNCATED] == batch_in[TRUNCATED]
assert batch_out["info"] == batch_in["info"] assert batch_out["info"] == batch_in["info"]
@@ -51,9 +51,9 @@ def test_batch_to_transition_observation_grouping():
f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
OBS_STATE: [1, 2, 3, 4], OBS_STATE: [1, 2, 3, 4],
ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]), ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]),
"next.reward": 1.5, REWARD: 1.5,
"next.done": True, DONE: True,
"next.truncated": False, TRUNCATED: False,
"info": {"episode": 42}, "info": {"episode": 42},
} }
@@ -115,9 +115,9 @@ def test_transition_to_batch_observation_flattening():
# Check other fields are mapped to next.* format # Check other fields are mapped to next.* format
assert batch[ACTION] == "action_data" assert batch[ACTION] == "action_data"
assert batch["next.reward"] == 1.5 assert batch[REWARD] == 1.5
assert batch["next.done"] assert batch[DONE]
assert not batch["next.truncated"] assert not batch[TRUNCATED]
assert batch["info"] == {"episode": 42} assert batch["info"] == {"episode": 42}
@@ -125,9 +125,9 @@ def test_no_observation_keys():
"""Test behavior when there are no observation.* keys.""" """Test behavior when there are no observation.* keys."""
batch = { batch = {
ACTION: torch.tensor([1.0, 2.0]), ACTION: torch.tensor([1.0, 2.0]),
"next.reward": 2.0, REWARD: 2.0,
"next.done": False, DONE: False,
"next.truncated": True, TRUNCATED: True,
"info": {"test": "no_obs"}, "info": {"test": "no_obs"},
} }
@@ -146,9 +146,9 @@ def test_no_observation_keys():
# Round trip should work # Round trip should work
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0])) assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0]))
assert reconstructed_batch["next.reward"] == 2.0 assert reconstructed_batch[REWARD] == 2.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch[DONE]
assert reconstructed_batch["next.truncated"] assert reconstructed_batch[TRUNCATED]
assert reconstructed_batch["info"] == {"test": "no_obs"} assert reconstructed_batch["info"] == {"test": "no_obs"}
@@ -173,9 +173,9 @@ def test_minimal_batch():
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch[OBS_STATE] == "minimal_state" assert reconstructed_batch[OBS_STATE] == "minimal_state"
assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5])) assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5]))
assert reconstructed_batch["next.reward"] == 0.0 assert reconstructed_batch[REWARD] == 0.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch[DONE]
assert not reconstructed_batch["next.truncated"] assert not reconstructed_batch[TRUNCATED]
assert reconstructed_batch["info"] == {} assert reconstructed_batch["info"] == {}
@@ -197,9 +197,9 @@ def test_empty_batch():
# Round trip # Round trip
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch[ACTION] is None assert reconstructed_batch[ACTION] is None
assert reconstructed_batch["next.reward"] == 0.0 assert reconstructed_batch[REWARD] == 0.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch[DONE]
assert not reconstructed_batch["next.truncated"] assert not reconstructed_batch[TRUNCATED]
assert reconstructed_batch["info"] == {} assert reconstructed_batch["info"] == {}
@@ -210,9 +210,9 @@ def test_complex_nested_observation():
f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
OBS_STATE: torch.randn(7), OBS_STATE: torch.randn(7),
ACTION: torch.randn(8), ACTION: torch.randn(8),
"next.reward": 3.14, REWARD: 3.14,
"next.done": False, DONE: False,
"next.truncated": True, TRUNCATED: True,
"info": {"episode_length": 200, "success": True}, "info": {"episode_length": 200, "success": True},
} }
@@ -240,9 +240,9 @@ def test_complex_nested_observation():
assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION]) assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION])
# Check other fields # Check other fields
assert batch["next.reward"] == reconstructed_batch["next.reward"] assert batch[REWARD] == reconstructed_batch[REWARD]
assert batch["next.done"] == reconstructed_batch["next.done"] assert batch[DONE] == reconstructed_batch[DONE]
assert batch["next.truncated"] == reconstructed_batch["next.truncated"] assert batch[TRUNCATED] == reconstructed_batch[TRUNCATED]
assert batch["info"] == reconstructed_batch["info"] assert batch["info"] == reconstructed_batch["info"]
@@ -267,13 +267,13 @@ def test_custom_converter():
batch = { batch = {
OBS_STATE: torch.randn(1, 4), OBS_STATE: torch.randn(1, 4),
ACTION: torch.randn(1, 2), ACTION: torch.randn(1, 2),
"next.reward": 1.0, REWARD: 1.0,
"next.done": False, DONE: False,
} }
result = processor(batch) result = processor(batch)
# Check the reward was doubled by our custom converter # Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0 assert result[REWARD] == 2.0
assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
assert torch.allclose(result[ACTION], batch[ACTION]) assert torch.allclose(result[ACTION], batch[ACTION])

View File

@@ -9,7 +9,7 @@ from lerobot.processor.converters import (
to_tensor, to_tensor,
transition_to_batch, transition_to_batch,
) )
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD
# Tests for the unified to_tensor function # Tests for the unified to_tensor function
@@ -201,8 +201,8 @@ def test_batch_to_transition_with_index_fields():
batch = { batch = {
OBS_STATE: torch.randn(1, 7), OBS_STATE: torch.randn(1, 7),
ACTION: torch.randn(1, 4), ACTION: torch.randn(1, 4),
"next.reward": 1.5, REWARD: 1.5,
"next.done": False, DONE: False,
"task": ["pick_cube"], "task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64), "index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64), "task_index": torch.tensor([3], dtype=torch.int64),

View File

@@ -35,7 +35,7 @@ from lerobot.processor import (
TransitionKey, TransitionKey,
) )
from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from tests.conftest import assert_contract_is_typed from tests.conftest import assert_contract_is_typed
@@ -258,9 +258,9 @@ def test_step_through_with_dict():
batch = { batch = {
OBS_IMAGE: None, OBS_IMAGE: None,
ACTION: None, ACTION: None,
"next.reward": 0.0, REWARD: 0.0,
"next.done": False, DONE: False,
"next.truncated": False, TRUNCATED: False,
"info": {}, "info": {},
} }
@@ -1843,9 +1843,9 @@ def test_save_load_with_custom_converter_functions():
batch = { batch = {
OBS_IMAGE: torch.randn(1, 3, 32, 32), OBS_IMAGE: torch.randn(1, 3, 32, 32),
ACTION: torch.randn(1, 7), ACTION: torch.randn(1, 7),
"next.reward": torch.tensor([1.0]), REWARD: torch.tensor([1.0]),
"next.done": torch.tensor([False]), DONE: torch.tensor([False]),
"next.truncated": torch.tensor([False]), TRUNCATED: torch.tensor([False]),
"info": {}, "info": {},
} }

View File

@@ -22,7 +22,7 @@ import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
from tests.fixtures.constants import DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_REPO_ID
@@ -380,9 +380,9 @@ def test_to_lerobot_dataset(tmp_path):
for feature, value in ds[i].items(): for feature, value in ds[i].items():
if feature == ACTION: if feature == ACTION:
assert torch.equal(value, buffer.actions[i]) assert torch.equal(value, buffer.actions[i])
elif feature == "next.reward": elif feature == REWARD:
assert torch.equal(value, buffer.rewards[i]) assert torch.equal(value, buffer.rewards[i])
elif feature == "next.done": elif feature == DONE:
assert torch.equal(value, buffer.dones[i]) assert torch.equal(value, buffer.dones[i])
elif feature == OBS_IMAGE: elif feature == OBS_IMAGE:
# Tensor -> numpy is not precise, so we have some diff there # Tensor -> numpy is not precise, so we have some diff there