chore: replace hard-coded next values with constants throughout all the source code (#2056)

This commit is contained in:
Steven Palma
2025-09-26 14:30:07 +02:00
committed by GitHub
parent ec40ccde0d
commit c5b5955c5a
13 changed files with 87 additions and 86 deletions

View File

@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -55,7 +55,7 @@ def resolve_delta_timestamps(
"""
delta_timestamps = {}
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]

View File

@@ -23,7 +23,7 @@ from typing import Any
import numpy as np
import torch
from lerobot.utils.constants import ACTION, OBS_PREFIX
from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -355,9 +355,9 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get(ACTION),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
reward=batch.get(REWARD, 0.0),
done=batch.get(DONE, False),
truncated=batch.get(TRUNCATED, False),
info=batch.get("info", {}),
complementary_data=complementary_data if complementary_data else None,
)
@@ -380,9 +380,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
batch = {
ACTION: transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
REWARD: transition.get(TransitionKey.REWARD, 0.0),
DONE: transition.get(TransitionKey.DONE, False),
TRUNCATED: transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_IMAGE
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
from lerobot.utils.transition import Transition
@@ -534,8 +534,8 @@ class ReplayBuffer:
features[ACTION] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
features["next.done"] = {"dtype": "bool", "shape": (1,)}
features[REWARD] = {"dtype": "float32", "shape": (1,)}
features[DONE] = {"dtype": "bool", "shape": (1,)}
# Add state keys
for key in self.states:
@@ -578,8 +578,8 @@ class ReplayBuffer:
# Fill action, reward, done
frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name
# Add complementary_info if available
@@ -648,7 +648,7 @@ class ReplayBuffer:
# Check if the dataset has "next.done" key
sample = dataset[0]
has_done_key = "next.done" in sample
has_done_key = DONE in sample
# Check for complementary_info keys
complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")]
@@ -671,11 +671,11 @@ class ReplayBuffer:
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
reward = float(current_sample[REWARD].item()) # ensure float
# Determine done flag - use next.done if available, otherwise infer from episode boundaries
if has_done_key:
done = bool(current_sample["next.done"].item()) # ensure bool
done = bool(current_sample[DONE].item()) # ensure bool
else:
# If this is the last frame or if next frame is in a different episode, mark as done
done = False

View File

@@ -25,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import DONE, REWARD
def select_rect_roi(img):
@@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
continue
if key in ("next.done", "next.reward"):
if key in (DONE, REWARD):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)

View File

@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -602,8 +602,8 @@ def control_loop(
action_features = teleop_device.action_features
features = {
ACTION: action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
REWARD: {"dtype": "float32", "shape": (1,), "names": None},
DONE: {"dtype": "bool", "shape": (1,), "names": None},
}
if use_gripper:
features["complementary_info.discrete_penalty"] = {
@@ -673,8 +673,8 @@ def control_loop(
frame = {
**observations,
ACTION: action_to_record.cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
DONE: np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)

View File

@@ -75,7 +75,7 @@ import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
@@ -166,11 +166,11 @@ def visualize_dataset(
for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch:
rr.log("next.done", rr.Scalar(batch["next.done"][i].item()))
if DONE in batch:
rr.log(DONE, rr.Scalar(batch[DONE][i].item()))
if "next.reward" in batch:
rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item()))
if REWARD in batch:
rr.log(REWARD, rr.Scalar(batch[REWARD][i].item()))
if "next.success" in batch:
rr.log("next.success", rr.Scalar(batch["next.success"][i].item()))

View File

@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -451,9 +451,9 @@ def _compile_episode_data(
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
DONE: rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
REWARD: rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
}
# For the last observation frame, all other keys will just be copy padded.