chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.constants import ACTION, OBS_PREFIX
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -57,7 +57,7 @@ def resolve_delta_timestamps(
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None:
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

View File

@@ -132,7 +132,7 @@ def aggregate_pipeline_dataset_features(
# Convert the processed features into the final dataset format.
dataset_features = {}
if processed_features[ACTION]:
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
if processed_features[OBS_STR]:
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))

View File

@@ -43,7 +43,7 @@ from lerobot.datasets.backward_compatibility import (
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
@@ -646,7 +646,7 @@ def hw_to_dataset_features(
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == "action":
if joint_fts and prefix == ACTION:
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
@@ -733,7 +733,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
type = FeatureType.ENV
elif key.startswith(OBS_STR):
type = FeatureType.STATE
elif key.startswith("action"):
elif key.startswith(ACTION):
type = FeatureType.ACTION
else:
continue

View File

@@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig):
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
@@ -93,13 +93,13 @@ class PushtEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"environment_state": OBS_ENV_STATE,
"pixels": OBS_IMAGE,
@@ -135,13 +135,13 @@ class XarmEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
@@ -259,12 +259,12 @@ class LiberoEnv(EnvConfig):
camera_name_mapping: dict[str, str] | None = (None,)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",

View File

@@ -394,7 +394,7 @@ class ACT(nn.Module):
latent dimension.
"""
if self.config.use_vae and self.training:
assert "action" in batch, (
assert ACTION in batch, (
"actions must be provided when using the variational objective in training mode."
)
@@ -404,7 +404,7 @@ class ACT(nn.Module):
batch_size = batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch and self.training:
if self.config.use_vae and ACTION in batch and self.training:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
@@ -412,7 +412,7 @@ class ACT(nn.Module):
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)

View File

@@ -82,7 +82,7 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@@ -306,10 +306,10 @@ class DiffusionModel(nn.Module):
}
"""
# Input validation.
assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch["action"].shape[1]
horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
@@ -317,7 +317,7 @@ class DiffusionModel(nn.Module):
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
trajectory = batch["action"]
trajectory = batch[ACTION]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
@@ -338,7 +338,7 @@ class DiffusionModel(nn.Module):
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
target = batch["action"]
target = batch[ACTION]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

View File

@@ -21,7 +21,7 @@ import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def display(tensor: torch.Tensor):
@@ -73,7 +73,7 @@ def main():
for cam_key, uint_chw_array in example["images"].items():
batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch[OBS_STATE] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch[ACTION] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel":
@@ -117,7 +117,7 @@ def main():
actions.append(action)
actions = torch.stack(actions, dim=1)
pi_actions = batch["action"]
pi_actions = batch[ACTION]
print("actions")
display(actions)
print()

View File

@@ -225,7 +225,7 @@ class SACConfig(PreTrainedConfig):
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property

View File

@@ -31,7 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
@@ -51,7 +51,7 @@ class SACPolicy(
self.config = config
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -158,7 +158,7 @@ class SACPolicy(
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch["action"]
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")

View File

@@ -92,7 +92,7 @@ class TDMPCPolicy(PreTrainedPolicy):
"""
self._queues = {
OBS_STATE: deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self.config.image_features:
self._queues[OBS_IMAGE] = deque(maxlen=1)

View File

@@ -23,7 +23,7 @@ from typing import Any
import numpy as np
import torch
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.constants import ACTION, OBS_PREFIX
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
action = batch.get("action")
action = batch.get(ACTION)
if action is not None and not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
@@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
action=batch.get(ACTION),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
@@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
batch = {
"action": transition.get(TransitionKey.ACTION),
ACTION: transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),

View File

@@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import ACTION
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
@@ -196,7 +197,7 @@ def detect_features_and_norm_modes(
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE # Default
@@ -215,7 +216,7 @@ def detect_features_and_norm_modes(
feature_type = FeatureType.VISUAL
elif "state" in key or "joint" in key or "position" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
@@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE

View File

@@ -26,6 +26,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey
@@ -272,7 +273,7 @@ class _NormalizationMixin:
Returns:
The transformed action tensor.
"""
processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse)
processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse)
return processed_action
def _apply_transform(

View File

@@ -5,6 +5,7 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
from lerobot.utils.constants import ACTION
@dataclass
@@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep):
return asdict(self)
def transform_features(self, features):
features[PipelineFeatureType.ACTION]["action"] = PolicyFeature(
features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
type=FeatureType.ACTION, shape=(len(self.motor_names),)
)
return features

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.constants import ACTION, OBS_IMAGE
from lerobot.utils.transition import Transition
@@ -467,7 +467,7 @@ class ReplayBuffer:
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
first_action = first_transition[ACTION].to(device)
# Get complementary info if available
first_complementary_info = None
@@ -492,7 +492,7 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data["action"]
action = data[ACTION]
replay_buffer.add(
state=data["state"],
@@ -530,8 +530,8 @@ class ReplayBuffer:
# Add "action"
sample_action = self.actions[0]
act_info = guess_feature_info(t=sample_action, name="action")
features["action"] = act_info
act_info = guess_feature_info(t=sample_action, name=ACTION)
features[ACTION] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
@@ -577,7 +577,7 @@ class ReplayBuffer:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name
@@ -668,7 +668,7 @@ class ReplayBuffer:
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
@@ -788,8 +788,8 @@ def concatenate_batch_transitions(
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
left_batch_transitions[ACTION] = torch.cat(
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0

View File

@@ -73,7 +73,7 @@ from lerobot.teleoperators import (
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -601,7 +601,7 @@ def control_loop(
if cfg.mode == "record":
action_features = teleop_device.action_features
features = {
"action": action_features,
ACTION: action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
@@ -672,7 +672,7 @@ def control_loop(
)
frame = {
**observations,
"action": action_to_record.cpu(),
ACTION: action_to_record.cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
}
@@ -733,7 +733,7 @@ def replay_trajectory(
download_videos=False,
)
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
_, info = env.reset()
@@ -741,7 +741,7 @@ def replay_trajectory(
start_time = time.perf_counter()
transition = create_transition(
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
action=action_data["action"],
action=action_data[ACTION],
)
transition = action_processor(transition)
env.step(transition[TransitionKey.ACTION])

View File

@@ -80,6 +80,7 @@ from lerobot.transport.utils import (
state_to_bytes,
)
from lerobot.utils.constants import (
ACTION,
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
@@ -402,7 +403,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -415,7 +416,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -460,7 +461,7 @@ def add_actor_information_and_train(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
@@ -474,7 +475,7 @@ def add_actor_information_and_train(
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
"action": actions,
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
@@ -1155,7 +1156,7 @@ def process_transitions(
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],
actions=transition["action"],
actions=transition[ACTION],
next_state=transition["next_state"],
):
logging.warning("[LEARNER] NaN detected in transition, skipping")

View File

@@ -23,7 +23,7 @@ from typing import Any
import cv2
import numpy as np
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
@@ -330,7 +330,7 @@ class LeKiwiClient(Robot):
actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32)
action_sent = {key: actions[i] for i, key in enumerate(self._state_order)}
action_sent["action"] = actions
action_sent[ACTION] = actions
return action_sent
def disconnect(self):

View File

@@ -75,7 +75,7 @@ import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.constants import ACTION, OBS_STATE
class EpisodeSampler(torch.utils.data.Sampler):
@@ -157,9 +157,9 @@ def visualize_dataset(
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
if ACTION in batch:
for dim_idx, val in enumerate(batch[ACTION][i]):
rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
if OBS_STATE in batch:

View File

@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import OBS_STR
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -213,7 +213,7 @@ def rollout(
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
"action": torch.stack(all_actions, dim=1),
ACTION: torch.stack(all_actions, dim=1),
"reward": torch.stack(all_rewards, dim=1),
"success": torch.stack(all_successes, dim=1),
"done": torch.stack(all_dones, dim=1),
@@ -440,14 +440,14 @@ def _compile_episode_data(
"""
ep_dicts = []
total_frames = 0
for ep_ix in range(rollout_data["action"].shape[0]):
for ep_ix in range(rollout_data[ACTION].shape[0]):
# + 2 to include the first done frame and the last observation frame.
num_frames = done_indices[ep_ix].item() + 2
total_frames += num_frames
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,

View File

@@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
so101_leader,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.constants import OBS_STR
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
is_headless,
@@ -319,7 +319,7 @@ def record_loop(
robot_type=robot.robot_type,
)
action_names = dataset.features["action"]["names"]
action_names = dataset.features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
}
@@ -361,7 +361,7 @@ def record_loop(
# Write to dataset
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, action_values, prefix="action")
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)

View File

@@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401
so100_follower,
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig):
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
robot.connect()
@@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig):
for idx in range(len(episode_frames)):
start_episode_t = time.perf_counter()
action_array = actions[idx]["action"]
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features["action"]["names"]):
for i, name in enumerate(dataset.features[ACTION]["names"]):
action[name] = action_array[i]
robot_obs = robot.get_observation()

View File

@@ -18,6 +18,8 @@ from typing import TypedDict
import torch
from lerobot.utils.constants import ACTION
class Transition(TypedDict):
state: dict[str, torch.Tensor]
@@ -39,7 +41,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
}
# Move action to device
transition["action"] = transition["action"].to(device, non_blocking=non_blocking)
transition[ACTION] = transition[ACTION].to(device, non_blocking=non_blocking)
# Move reward and done if they are tensors
if isinstance(transition["reward"], torch.Tensor):