chore: replace hard-coded obs values with constants throughout all the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-25 15:36:47 +02:00
committed by GitHub
parent ddba994d73
commit 43d878a102
52 changed files with 659 additions and 649 deletions

View File

@@ -41,6 +41,7 @@ from lerobot.datasets.video_utils import (
decode_video_frames_torchvision, decode_video_frames_torchvision,
encode_video_frames, encode_video_frames,
) )
from lerobot.utils.constants import OBS_IMAGE
BASE_ENCODING = OrderedDict( BASE_ENCODING = OrderedDict(
[ [
@@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None) hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera # We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0]) imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate( for i, item in enumerate(

View File

@@ -21,6 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import make_default_processors from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun
@@ -42,7 +43,7 @@ policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features # Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action") action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation") obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features} dataset_features = {**action_features, **obs_features}
# Create the dataset # Create the dataset

View File

@@ -22,6 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.scripts.lerobot_record import record_loop from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.constants import OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun
@@ -48,7 +49,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m
# Configure the dataset features # Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action") action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation") obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features} dataset_features = {**action_features, **obs_features}
# Create the dataset # Create the dataset

View File

@@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot from lerobot.robots.robot import Robot
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import init_logging from lerobot.utils.utils import init_logging
Action = torch.Tensor Action = torch.Tensor
@@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy(
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False) return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
def is_image_key(k: str) -> bool: def is_image_key(k: str) -> bool:
@@ -141,7 +141,7 @@ def make_lerobot_observation(
lerobot_features: dict[str, dict], lerobot_features: dict[str, dict],
) -> LeRobotObservation: ) -> LeRobotObservation:
"""Make a lerobot observation from a raw observation.""" """Make a lerobot observation from a raw observation."""
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation") return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
def prepare_raw_observation( def prepare_raw_observation(

View File

@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
) )
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import OBS_PREFIX
IMAGENET_STATS = { IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -58,7 +59,7 @@ def resolve_delta_timestamps(
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None: if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None: if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0: if len(delta_timestamps) == 0:

View File

@@ -19,7 +19,7 @@ from typing import Any
from lerobot.configs.types import PipelineFeatureType from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline from lerobot.processor import DataProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
def create_initial_features( def create_initial_features(
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(
# Intermediate storage for categorized and filtered features. # Intermediate storage for categorized and filtered features.
processed_features: dict[str, dict[str, Any]] = { processed_features: dict[str, dict[str, Any]] = {
"action": {}, ACTION: {},
"observation": {}, OBS_STR: {},
} }
images_token = OBS_IMAGES.split(".")[-1] images_token = OBS_IMAGES.split(".")[-1]
@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
# 3. Add the feature to the appropriate group with a clean name. # 3. Add the feature to the appropriate group with a clean name.
name = strip_prefix(key, PREFIXES_TO_STRIP) name = strip_prefix(key, PREFIXES_TO_STRIP)
if is_action: if is_action:
processed_features["action"][name] = value processed_features[ACTION][name] = value
else: else:
processed_features["observation"][name] = value processed_features[OBS_STR][name] = value
# Convert the processed features into the final dataset format. # Convert the processed features into the final dataset format.
dataset_features = {} dataset_features = {}
if processed_features["action"]: if processed_features[ACTION]:
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
if processed_features["observation"]: if processed_features[OBS_STR]:
dataset_features.update( dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
)
return dataset_features return dataset_features

View File

@@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import (
BackwardCompatibilityError, BackwardCompatibilityError,
ForwardCompatibilityError, ForwardCompatibilityError,
) )
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string from lerobot.utils.utils import is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
@@ -652,7 +653,7 @@ def hw_to_dataset_features(
"names": list(joint_fts), "names": list(joint_fts),
} }
if joint_fts and prefix == "observation": if joint_fts and prefix == OBS_STR:
features[f"{prefix}.state"] = { features[f"{prefix}.state"] = {
"dtype": "float32", "dtype": "float32",
"shape": (len(joint_fts),), "shape": (len(joint_fts),),
@@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1]) shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state": elif key == OBS_ENV_STATE:
type = FeatureType.ENV type = FeatureType.ENV
elif key.startswith("observation"): elif key.startswith(OBS_STR):
type = FeatureType.STATE type = FeatureType.STATE
elif key.startswith("action"): elif key.startswith("action"):
type = FeatureType.ACTION type = FeatureType.ACTION

View File

@@ -26,6 +26,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape from lerobot.utils.utils import get_channel_first_image_shape
@@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations = {} return_observations = {}
if "pixels" in observations: if "pixels" in observations:
if isinstance(observations["pixels"], dict): if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
else: else:
imgs = {"observation.image": observations["pixels"]} imgs = {OBS_IMAGE: observations["pixels"]}
for imgkey, img in imgs.items(): for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()? # TODO(aliberts, rcadene): use transforms.ToTensor()?
@@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if env_state.dim() == 1: if env_state.dim() == 1:
env_state = env_state.unsqueeze(0) env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float() agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1: if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0) agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos return_observations[OBS_STATE] = agent_pos
return return_observations return return_observations

View File

@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGES from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTPolicy(PreTrainedPolicy): class ACTPolicy(PreTrainedPolicy):
@@ -398,10 +398,10 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode." "actions must be provided when using the variational objective in training mode."
) )
if "observation.images" in batch: if OBS_IMAGES in batch:
batch_size = batch["observation.images"][0].shape[0] batch_size = batch[OBS_IMAGES][0].shape[0]
else: else:
batch_size = batch["observation.environment_state"].shape[0] batch_size = batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder. # Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch and self.training: if self.config.use_vae and "action" in batch and self.training:
@@ -410,7 +410,7 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D) ) # (B, 1, D)
if self.config.robot_state_feature: if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
@@ -430,7 +430,7 @@ class ACT(nn.Module):
cls_joint_is_pad = torch.full( cls_joint_is_pad = torch.full(
(batch_size, 2 if self.config.robot_state_feature else 1), (batch_size, 2 if self.config.robot_state_feature else 1),
False, False,
device=batch["observation.state"].device, device=batch[OBS_STATE].device,
) )
key_padding_mask = torch.cat( key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1 [cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -454,7 +454,7 @@ class ACT(nn.Module):
mu = log_sigma_x2 = None mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device batch[OBS_STATE].device
) )
# Prepare transformer encoder inputs. # Prepare transformer encoder inputs.
@@ -462,18 +462,16 @@ class ACT(nn.Module):
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token. # Robot state token.
if self.config.robot_state_feature: if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
# Environment state token. # Environment state token.
if self.config.env_state_feature: if self.config.env_state_feature:
encoder_in_tokens.append( encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
if self.config.image_features: if self.config.image_features:
# For a list of images, the H and W may vary but H*W is constant. # For a list of images, the H and W may vary but H*W is constant.
# NOTE: If modifying this section, verify on MPS devices that # NOTE: If modifying this section, verify on MPS devices that
# gradients remain stable (no explosions or NaNs). # gradients remain stable (no explosions or NaNs).
for img in batch["observation.images"]: for img in batch[OBS_IMAGES]:
cam_features = self.backbone(img)["feature_map"] cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) cam_features = self.encoder_img_feat_input_proj(cam_features)

View File

@@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy):
def reset(self): def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`""" """Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = { self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps), OBS_STATE: deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps), "action": deque(maxlen=self.config.n_action_steps),
} }
if self.config.image_features: if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature: if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
@@ -234,7 +234,7 @@ class DiffusionModel(nn.Module):
if self.config.image_features: if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera: if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first. # Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
img_features_list = torch.cat( img_features_list = torch.cat(
[ [
encoder(images) encoder(images)
@@ -249,7 +249,7 @@ class DiffusionModel(nn.Module):
else: else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder. # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder( img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
) )
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features). # feature dim (effectively concatenating the camera features).
@@ -275,7 +275,7 @@ class DiffusionModel(nn.Module):
"observation.environment_state": (B, n_obs_steps, environment_dim) "observation.environment_state": (B, n_obs_steps, environment_dim)
} }
""" """
batch_size, n_obs_steps = batch["observation.state"].shape[:2] batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector. # Encode image features and concatenate them all together along with the state vector.
@@ -306,9 +306,9 @@ class DiffusionModel(nn.Module):
} }
""" """
# Input validation. # Input validation.
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch["observation.state"].shape[1] n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch["action"].shape[1] horizon = batch["action"].shape[1]
assert horizon == self.config.horizon assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps assert n_obs_steps == self.config.n_obs_steps

View File

@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import ( from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0") @PreTrainedConfig.register_subclass("pi0")
@@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig):
# raise ValueError("You must provide at least one image or the environment state among the inputs.") # raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}" key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 480, 640), shape=(3, 480, 640),

View File

@@ -21,6 +21,7 @@ import torch
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy from lerobot.policies.factory import make_policy
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
def display(tensor: torch.Tensor): def display(tensor: torch.Tensor):
@@ -60,26 +61,26 @@ def main():
# Override stats # Override stats
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
dataset_meta.stats["observation.state"]["mean"] = torch.tensor( dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor(
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
) )
dataset_meta.stats["observation.state"]["std"] = torch.tensor( dataset_meta.stats[OBS_STATE]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
) )
# Create LeRobot batch from Jax # Create LeRobot batch from Jax
batch = {} batch = {}
for cam_key, uint_chw_array in example["images"].items(): for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"]) batch[OBS_STATE] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"]) batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"] batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel": if model_name == "pi0_aloha_towel":
del batch["observation.images.cam_low"] del batch[f"{OBS_IMAGES}.cam_low"]
elif model_name == "pi0_aloha_sim": elif model_name == "pi0_aloha_sim":
batch["observation.images.top"] = batch["observation.images.cam_high"] batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"]
del batch["observation.images.cam_high"] del batch[f"{OBS_IMAGES}.cam_high"]
# Batchify # Batchify
for key in batch: for key in batch:

View File

@@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import ( from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0fast") @PreTrainedConfig.register_subclass("pi0fast")
@@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}" key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 480, 640), shape=(3, 480, 640),

View File

@@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
@@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module):
) )
def _init_state_layers(self) -> None: def _init_state_layers(self) -> None:
self.has_env = "observation.environment_state" in self.config.input_features self.has_env = OBS_ENV_STATE in self.config.input_features
self.has_state = "observation.state" in self.config.input_features self.has_state = OBS_STATE in self.config.input_features
if self.has_env: if self.has_env:
dim = self.config.input_features["observation.environment_state"].shape[0] dim = self.config.input_features[OBS_ENV_STATE].shape[0]
self.env_encoder = nn.Sequential( self.env_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim), nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim), nn.LayerNorm(self.config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
if self.has_state: if self.has_state:
dim = self.config.input_features["observation.state"].shape[0] dim = self.config.input_features[OBS_STATE].shape[0]
self.state_encoder = nn.Sequential( self.state_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim), nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim), nn.LayerNorm(self.config.latent_dim),
@@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module):
cache = self.get_cached_image_features(obs) cache = self.get_cached_image_features(obs)
parts.append(self._encode_images(cache, detach)) parts.append(self._encode_images(cache, detach))
if self.has_env: if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"])) parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
if self.has_state: if self.has_state:
parts.append(self.state_encoder(obs["observation.state"])) parts.append(self.state_encoder(obs[OBS_STATE]))
if parts: if parts:
return torch.cat(parts, dim=-1) return torch.cat(parts, dim=-1)

View File

@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import OBS_IMAGE
@PreTrainedConfig.register_subclass(name="reward_classifier") @PreTrainedConfig.register_subclass(name="reward_classifier")
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate feature configurations.""" """Validate feature configurations."""
has_image = any(key.startswith("observation.image") for key in self.input_features) has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
if not has_image: if not has_image:
raise ValueError( raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features" "You must provide an image observation (key starting with 'observation.image') in the input features"

View File

@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import ( from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("smolvla") @PreTrainedConfig.register_subclass("smolvla")
@@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
for i in range(self.empty_cameras): for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}" key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature( empty_camera = PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 480, 640), shape=(3, 480, 640),

View File

@@ -38,7 +38,7 @@ from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
class TDMPCPolicy(PreTrainedPolicy): class TDMPCPolicy(PreTrainedPolicy):
@@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy):
called on `env.reset()` called on `env.reset()`
""" """
self._queues = { self._queues = {
"observation.state": deque(maxlen=1), OBS_STATE: deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
} }
if self.config.image_features: if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1) self._queues[OBS_IMAGE] = deque(maxlen=1)
if self.config.env_state_feature: if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1) self._queues[OBS_ENV_STATE] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step. # CEM for the next step.
self._prev_mean: torch.Tensor | None = None self._prev_mean: torch.Tensor | None = None
@@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy):
action = batch[ACTION] # (t, b, action_dim) action = batch[ACTION] # (t, b, action_dim)
reward = batch[REWARD] # (t, b) reward = batch[REWARD] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")} observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
# Apply random image augmentations. # Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0: if self.config.image_features and self.config.max_random_shift_ratio > 0:
@@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs temporal_loss_coeffs
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
# `z_preds` depends on the current observation and the actions. # `z_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0] * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"] * ~batch["action_is_pad"]
# `z_targets` depends on the next observation. # `z_targets` depends on the next observation.
* ~batch["observation.state_is_pad"][1:] * ~batch[f"{OBS_STR}.state_is_pad"][1:]
) )
.sum(0) .sum(0)
.mean() .mean()
@@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy):
* F.mse_loss(reward_preds, reward, reduction="none") * F.mse_loss(reward_preds, reward, reduction="none")
* ~batch["next.reward_is_pad"] * ~batch["next.reward_is_pad"]
# `reward_preds` depends on the current observation and the actions. # `reward_preds` depends on the current observation and the actions.
* ~batch["observation.state_is_pad"][0] * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"] * ~batch["action_is_pad"]
) )
.sum(0) .sum(0)
@@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy):
reduction="none", reduction="none",
).sum(0) # sum over ensemble ).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions. # `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0] * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"] * ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations. # q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"] * ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:] * ~batch[f"{OBS_STR}.state_is_pad"][1:]
) )
.sum(0) .sum(0)
.mean() .mean()
@@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy):
temporal_loss_coeffs temporal_loss_coeffs
* raw_v_value_loss * raw_v_value_loss
# `v_targets` depends on the first observation and the actions, as does `v_preds`. # `v_targets` depends on the first observation and the actions, as does `v_preds`.
* ~batch["observation.state_is_pad"][0] * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"] * ~batch["action_is_pad"]
) )
.sum(0) .sum(0)
@@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy):
* mse * mse
* temporal_loss_coeffs * temporal_loss_coeffs
# `action_preds` depends on the first observation and the actions. # `action_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0] * ~batch[f"{OBS_STR}.state_is_pad"][0]
* ~batch["action_is_pad"] * ~batch["action_is_pad"]
).mean() ).mean()

View File

@@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy):
batch.pop(ACTION) batch.pop(ACTION)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key. # NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch: if ACTION in batch:
batch.pop(ACTION) batch.pop(ACTION)
@@ -340,14 +340,12 @@ class VQBeTModel(nn.Module):
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation. # Input validation.
assert set(batch).issuperset({"observation.state", "observation.images"}) assert set(batch).issuperset({OBS_STATE, OBS_IMAGES})
batch_size, n_obs_steps = batch["observation.state"].shape[:2] batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims). # Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder( img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ..."))
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
)
# Separate batch and sequence dims. # Separate batch and sequence dims.
img_features = einops.rearrange( img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
@@ -359,9 +357,7 @@ class VQBeTModel(nn.Module):
img_features img_features
) # (batch, obs_step, number of different cameras, projection dims) ) # (batch, obs_step, number of different cameras, projection dims)
input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
input_tokens.append( input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims)
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
# Interleave tokens by stacking and rearranging. # Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2) input_tokens = torch.stack(input_tokens, dim=2)

View File

@@ -23,6 +23,8 @@ from typing import Any
import numpy as np import numpy as np
import torch import torch
from lerobot.utils.constants import OBS_PREFIX
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
raise ValueError(f"Action should be a PolicyAction type got {type(action)}") raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
# Extract observation and complementary data keys. # Extract observation and complementary data keys.
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
complementary_data = _extract_complementary_data(batch) complementary_data = _extract_complementary_data(batch)
return create_transition( return create_transition(

View File

@@ -21,7 +21,7 @@ import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
# Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
for old_prefix, new_prefix in prefix_pairs.items(): for old_prefix, new_prefix in prefix_pairs.items():
prefixed_old = f"observation.{old_prefix}" prefixed_old = f"{OBS_STR}.{old_prefix}"
if key.startswith(prefixed_old): if key.startswith(prefixed_old):
suffix = key[len(prefixed_old) :] suffix = key[len(prefixed_old) :]
new_key = f"{new_prefix}{suffix}" new_key = f"{new_prefix}{suffix}"
@@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
# Exact-name rules (pixels, environment_state, agent_pos) # Exact-name rules (pixels, environment_state, agent_pos)
for old, new in exact_pairs.items(): for old, new in exact_pairs.items():
if key == old or key == f"observation.{old}": if key == old or key == f"{OBS_STR}.{old}":
new_key = new new_key = new
new_features[src_ft][new_key] = feat new_features[src_ft][new_key] = feat
handled = True handled = True

View File

@@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
@@ -240,7 +241,7 @@ class ReplayBuffer:
idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation # Identify image keys that need augmentation
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []
# Create batched state and next_state # Create batched state and next_state
batch_state = {} batch_state = {}

View File

@@ -73,6 +73,7 @@ from lerobot.teleoperators import (
) )
from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.robot_utils import busy_wait from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
@@ -180,7 +181,7 @@ class RobotEnv(gym.Env):
# Define observation spaces for images and other states. # Define observation spaces for images and other states.
if current_observation is not None and "pixels" in current_observation: if current_observation is not None and "pixels" in current_observation:
prefix = "observation.images" prefix = OBS_IMAGES
observation_spaces = { observation_spaces = {
f"{prefix}.{key}": gym.spaces.Box( f"{prefix}.{key}": gym.spaces.Box(
low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
@@ -190,7 +191,7 @@ class RobotEnv(gym.Env):
if current_observation is not None: if current_observation is not None:
agent_pos = current_observation["agent_pos"] agent_pos = current_observation["agent_pos"]
observation_spaces["observation.state"] = gym.spaces.Box( observation_spaces[OBS_STATE] = gym.spaces.Box(
low=0, low=0,
high=10, high=10,
shape=agent_pos.shape, shape=agent_pos.shape,
@@ -612,7 +613,7 @@ def control_loop(
} }
for key, value in transition[TransitionKey.OBSERVATION].items(): for key, value in transition[TransitionKey.OBSERVATION].items():
if key == "observation.state": if key == OBS_STATE:
features[key] = { features[key] = {
"dtype": "float32", "dtype": "float32",
"shape": value.squeeze(0).shape, "shape": value.squeeze(0).shape,

View File

@@ -23,6 +23,7 @@ from typing import Any
import cv2 import cv2
import numpy as np import numpy as np
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot from ..robot import Robot
@@ -203,7 +204,7 @@ class LeKiwiClient(Robot):
state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)
obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec} obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}
# Decode images # Decode images
current_frames: dict[str, np.ndarray] = {} current_frames: dict[str, np.ndarray] = {}

View File

@@ -75,6 +75,7 @@ import torch.utils.data
import tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_STATE
class EpisodeSampler(torch.utils.data.Sampler): class EpisodeSampler(torch.utils.data.Sampler):
@@ -161,8 +162,8 @@ def visualize_dataset(
rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space) # display each dimension of observed state space (e.g. agent position in joint space)
if "observation.state" in batch: if OBS_STATE in batch:
for dim_idx, val in enumerate(batch["observation.state"][i]): for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))
if "next.done" in batch: if "next.done" in batch:

View File

@@ -81,6 +81,7 @@ from lerobot.envs.utils import (
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import OBS_STR
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
@@ -221,7 +222,7 @@ def rollout(
stacked_observations = {} stacked_observations = {}
for key in all_observations[0]: for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret["observation"] = stacked_observations ret[OBS_STR] = stacked_observations
if hasattr(policy, "use_original_modules"): if hasattr(policy, "use_original_modules"):
policy.use_original_modules() policy.use_original_modules()
@@ -459,8 +460,8 @@ def _compile_episode_data(
for k in ep_dict: for k in ep_dict:
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
for key in rollout_data["observation"]: for key in rollout_data[OBS_STR]:
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames]
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)

View File

@@ -109,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
so101_leader, so101_leader,
) )
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.constants import OBS_STR
from lerobot.utils.control_utils import ( from lerobot.utils.control_utils import (
init_keyboard_listener, init_keyboard_listener,
is_headless, is_headless,
@@ -303,7 +304,7 @@ def record_loop(
obs_processed = robot_observation_processor(obs) obs_processed = robot_observation_processor(obs)
if policy is not None or dataset is not None: if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
# Get action from either policy or teleop # Get action from either policy or teleop
if policy is not None and preprocessor is not None and postprocessor is not None: if policy is not None and preprocessor is not None and postprocessor is not None:

View File

@@ -17,19 +17,21 @@ from pathlib import Path
from huggingface_hub.constants import HF_HOME from huggingface_hub.constants import HF_HOME
OBS_ENV_STATE = "observation.environment_state" OBS_STR = "observation"
OBS_STATE = "observation.state" OBS_PREFIX = OBS_STR + "."
OBS_IMAGE = "observation.image" OBS_ENV_STATE = OBS_STR + ".environment_state"
OBS_IMAGES = "observation.images" OBS_STATE = OBS_STR + ".state"
OBS_LANGUAGE = "observation.language" OBS_IMAGE = OBS_STR + ".image"
OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
ACTION = "action" ACTION = "action"
REWARD = "next.reward" REWARD = "next.reward"
TRUNCATED = "next.truncated" TRUNCATED = "next.truncated"
DONE = "next.done" DONE = "next.done"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
ROBOTS = "robots" ROBOTS = "robots"
ROBOT_TYPE = "robot_type" ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators" TELEOPERATORS = "teleoperators"

View File

@@ -19,6 +19,8 @@ from typing import Any
import numpy as np import numpy as np
import rerun as rr import rerun as rr
from .constants import OBS_PREFIX, OBS_STR
def init_rerun(session_name: str = "lerobot_control_loop") -> None: def init_rerun(session_name: str = "lerobot_control_loop") -> None:
"""Initializes the Rerun SDK for visualizing the control loop.""" """Initializes the Rerun SDK for visualizing the control loop."""
@@ -63,7 +65,7 @@ def log_rerun_data(
for k, v in observation.items(): for k, v in observation.items():
if v is None: if v is None:
continue continue
key = k if str(k).startswith("observation.") else f"observation.{k}" key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
if _is_scalar(v): if _is_scalar(v):
rr.log(key, rr.Scalar(float(v))) rr.log(key, rr.Scalar(float(v)))

View File

@@ -24,6 +24,7 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import OBS_STR
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
@@ -92,7 +93,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# for backward compatibility # for backward compatibility
if k == "task": if k == "task":
continue continue
if k.startswith("observation"): if k.startswith(OBS_STR):
obs[k] = batch[k] obs[k] = batch[k]
if hasattr(train_cfg.policy, "n_action_steps"): if hasattr(train_cfg.policy, "n_action_steps"):

View File

@@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import (
resize_robot_observation_image, resize_robot_observation_image,
) )
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
# FPSTracker # FPSTracker
@@ -115,7 +116,7 @@ def test_timed_action_getters():
def test_timed_observation_getters(): def test_timed_observation_getters():
"""TimedObservation stores & returns timestamp, dict and timestep.""" """TimedObservation stores & returns timestamp, dict and timestep."""
ts = time.time() ts = time.time()
obs_dict = {"observation.state": torch.ones(6)} obs_dict = {OBS_STATE: torch.ones(6)}
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
@@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters():
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# TimedObservation # TimedObservation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
obs_dict = {"observation.state": torch.arange(4).float()} obs_dict = {OBS_STATE: torch.arange(4).float()}
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
to_bytes = pickle.dumps(to_in) # nosec to_bytes = pickle.dumps(to_in) # nosec
@@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters():
assert to_out.get_timestep() == 7 assert to_out.get_timestep() == 7
assert to_out.must_go is True assert to_out.must_go is True
assert to_out.get_observation().keys() == obs_dict.keys() assert to_out.get_observation().keys() == obs_dict.keys()
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
@@ -187,7 +188,7 @@ def test_observations_similar_true():
"""Distance below atol → observations considered similar.""" """Distance below atol → observations considered similar."""
# Create mock lerobot features for the similarity check # Create mock lerobot features for the similarity check
lerobot_features = { lerobot_features = {
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": [4], "shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"], "names": ["shoulder", "elbow", "wrist", "gripper"],
@@ -222,17 +223,17 @@ def _create_mock_robot_observation():
def _create_mock_lerobot_features(): def _create_mock_lerobot_features():
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" """Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
return { return {
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": [4], "shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"], "names": ["shoulder", "elbow", "wrist", "gripper"],
}, },
"observation.images.laptop": { f"{OBS_IMAGES}.laptop": {
"dtype": "image", "dtype": "image",
"shape": [480, 640, 3], "shape": [480, 640, 3],
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],
}, },
"observation.images.phone": { f"{OBS_IMAGES}.phone": {
"dtype": "image", "dtype": "image",
"shape": [480, 640, 3], "shape": [480, 640, 3],
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],
@@ -243,11 +244,11 @@ def _create_mock_lerobot_features():
def _create_mock_policy_image_features(): def _create_mock_policy_image_features():
"""Create mock policy image features with different resolutions.""" """Create mock policy image features with different resolutions."""
return { return {
"observation.images.laptop": PolicyFeature( f"{OBS_IMAGES}.laptop": PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 224, 224), # Policy expects smaller resolution shape=(3, 224, 224), # Policy expects smaller resolution
), ),
"observation.images.phone": PolicyFeature( f"{OBS_IMAGES}.phone": PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 160, 160), # Different resolution for second camera shape=(3, 160, 160), # Different resolution for second camera
), ),
@@ -306,21 +307,21 @@ def test_prepare_raw_observation():
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
# Check that state is properly extracted and batched # Check that state is properly extracted and batched
assert "observation.state" in prepared assert OBS_STATE in prepared
state = prepared["observation.state"] state = prepared[OBS_STATE]
assert isinstance(state, torch.Tensor) assert isinstance(state, torch.Tensor)
assert state.shape == (1, 4) # Batched state assert state.shape == (1, 4) # Batched state
# Check that images are processed and resized # Check that images are processed and resized
assert "observation.images.laptop" in prepared assert f"{OBS_IMAGES}.laptop" in prepared
assert "observation.images.phone" in prepared assert f"{OBS_IMAGES}.phone" in prepared
laptop_img = prepared["observation.images.laptop"] laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
phone_img = prepared["observation.images.phone"] phone_img = prepared[f"{OBS_IMAGES}.phone"]
# Check image shapes match policy requirements # Check image shapes match policy requirements
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
assert phone_img.shape == policy_image_features["observation.images.phone"].shape assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
# Check that images are tensors # Check that images are tensors
assert isinstance(laptop_img, torch.Tensor) assert isinstance(laptop_img, torch.Tensor)
@@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic():
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
# Check that all expected keys are present # Check that all expected keys are present
assert "observation.state" in observation assert OBS_STATE in observation
assert "observation.images.laptop" in observation assert f"{OBS_IMAGES}.laptop" in observation
assert "observation.images.phone" in observation assert f"{OBS_IMAGES}.phone" in observation
# Check state processing # Check state processing
state = observation["observation.state"] state = observation[OBS_STATE]
assert isinstance(state, torch.Tensor) assert isinstance(state, torch.Tensor)
assert state.device.type == device assert state.device.type == device
assert state.shape == (1, 4) # Batched assert state.shape == (1, 4) # Batched
# Check image processing # Check image processing
laptop_img = observation["observation.images.laptop"] laptop_img = observation[f"{OBS_IMAGES}.laptop"]
phone_img = observation["observation.images.phone"] phone_img = observation[f"{OBS_IMAGES}.phone"]
# Images should have batch dimension: (B, C, H, W) # Images should have batch dimension: (B, C, H, W)
assert laptop_img.shape == (1, 3, 224, 224) assert laptop_img.shape == (1, 3, 224, 224)
@@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content():
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
lerobot_features = { lerobot_features = {
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": [4], "shape": [4],
"names": ["shoulder", "elbow", "wrist", "gripper"], "names": ["shoulder", "elbow", "wrist", "gripper"],
}, },
"observation.images.laptop": { f"{OBS_IMAGES}.laptop": {
"dtype": "image", "dtype": "image",
"shape": [100, 100, 3], "shape": [100, 100, 3],
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],
}, },
} }
policy_image_features = { policy_image_features = {
"observation.images.laptop": PolicyFeature( f"{OBS_IMAGES}.laptop": PolicyFeature(
type=FeatureType.VISUAL, type=FeatureType.VISUAL,
shape=(3, 50, 50), # Downsamples from 100x100 shape=(3, 50, 50), # Downsamples from 100x100
) )
@@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content():
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
# Check that the center region has higher values than corners # Check that the center region has higher values than corners
# Due to bilinear interpolation, exact values will change but pattern should remain # Due to bilinear interpolation, exact values will change but pattern should remain

View File

@@ -23,6 +23,7 @@ import pytest
import torch import torch
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from lerobot.utils.constants import OBS_STATE
from tests.utils import require_package from tests.utils import require_package
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -44,7 +45,7 @@ class MockPolicy:
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return a chunk of 20 dummy actions.""" """Return a chunk of 20 dummy actions."""
batch_size = len(observation["observation.state"]) batch_size = len(observation[OBS_STATE])
return torch.zeros(batch_size, 20, 6) return torch.zeros(batch_size, 20, 6)
def __init__(self): def __init__(self):
@@ -77,7 +78,7 @@ def policy_server():
# Add mock lerobot_features that the observation similarity functions need # Add mock lerobot_features that the observation similarity functions need
server.lerobot_features = { server.lerobot_features = {
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": [6], "shape": [6],
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],

View File

@@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import (
sample_images, sample_images,
sample_indices, sample_indices,
) )
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def mock_load_image_as_numpy(path, dtype, channel_first): def mock_load_image_as_numpy(path, dtype, channel_first):
@@ -136,21 +137,21 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats(): def test_compute_episode_stats():
episode_data = { episode_data = {
"observation.image": [f"image_{i}.jpg" for i in range(100)], OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
"observation.state": np.random.rand(100, 10), OBS_STATE: np.random.rand(100, 10),
} }
features = { features = {
"observation.image": {"dtype": "image"}, OBS_IMAGE: {"dtype": "image"},
"observation.state": {"dtype": "numeric"}, OBS_STATE: {"dtype": "numeric"},
} }
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features) stats = compute_episode_stats(episode_data, features)
assert "observation.image" in stats and "observation.state" in stats assert OBS_IMAGE in stats and OBS_STATE in stats
assert stats["observation.image"]["count"].item() == 100 assert stats[OBS_IMAGE]["count"].item() == 100
assert stats["observation.state"]["count"].item() == 100 assert stats[OBS_STATE]["count"].item() == 100
assert stats["observation.image"]["mean"].shape == (3, 1, 1) assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
def test_assert_type_and_shape_valid(): def test_assert_type_and_shape_valid():
@@ -224,38 +225,38 @@ def test_aggregate_feature_stats():
def test_aggregate_stats(): def test_aggregate_stats():
all_stats = [ all_stats = [
{ {
"observation.image": { OBS_IMAGE: {
"min": [1, 2, 3], "min": [1, 2, 3],
"max": [10, 20, 30], "max": [10, 20, 30],
"mean": [5.5, 10.5, 15.5], "mean": [5.5, 10.5, 15.5],
"std": [2.87, 5.87, 8.87], "std": [2.87, 5.87, 8.87],
"count": 10, "count": 10,
}, },
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
}, },
{ {
"observation.image": { OBS_IMAGE: {
"min": [2, 1, 0], "min": [2, 1, 0],
"max": [15, 10, 5], "max": [15, 10, 5],
"mean": [8.5, 5.5, 2.5], "mean": [8.5, 5.5, 2.5],
"std": [3.42, 2.42, 1.42], "std": [3.42, 2.42, 1.42],
"count": 15, "count": 15,
}, },
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
}, },
] ]
expected_agg_stats = { expected_agg_stats = {
"observation.image": { OBS_IMAGE: {
"min": [1, 1, 0], "min": [1, 1, 0],
"max": [15, 20, 30], "max": [15, 20, 30],
"mean": [7.3, 7.5, 7.7], "mean": [7.3, 7.5, 7.7],
"std": [3.5317, 4.8267, 8.5581], "std": [3.5317, 4.8267, 8.5581],
"count": 25, "count": 25,
}, },
"observation.state": { OBS_STATE: {
"min": 1, "min": 1,
"max": 15, "max": 15,
"mean": 7.3, "mean": 7.3,
@@ -283,7 +284,7 @@ def test_aggregate_stats():
for fkey, stats in ep_stats.items(): for fkey, stats in ep_stats.items():
for k in stats: for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count": if fkey == OBS_IMAGE and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else: else:
stats[k] = stats[k].reshape(1) stats[k] = stats[k].reshape(1)
@@ -292,7 +293,7 @@ def test_aggregate_stats():
for fkey, stats in expected_agg_stats.items(): for fkey, stats in expected_agg_stats.items():
for k in stats: for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count": if fkey == OBS_IMAGE and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else: else:
stats[k] = stats[k].reshape(1) stats[k] = stats[k].reshape(1)

View File

@@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.utils.constants import OBS_IMAGES
def test_default_parameters(): def test_default_parameters():
@@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup():
def test_non_vector_last_wins_for_images(): def test_non_vector_last_wins_for_images():
# Non-vector (images) with same name should be overwritten by the last image specified # Non-vector (images) with same name should be overwritten by the last image specified
g1 = { g1 = {
"observation.images.front": { f"{OBS_IMAGES}.front": {
"dtype": "image", "dtype": "image",
"shape": (3, 480, 640), "shape": (3, 480, 640),
"names": ["channels", "height", "width"], "names": ["channels", "height", "width"],
} }
} }
g2 = { g2 = {
"observation.images.front": { f"{OBS_IMAGES}.front": {
"dtype": "image", "dtype": "image",
"shape": (3, 720, 1280), "shape": (3, 720, 1280),
"names": ["channels", "height", "width"], "names": ["channels", "height", "width"],
@@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images():
} }
out = combine_feature_dicts(g1, g2) out = combine_feature_dicts(g1, g2)
assert out["observation.images.front"]["shape"] == (3, 720, 1280) assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280)
assert out["observation.images.front"]["dtype"] == "image" assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image"
def test_dtype_mismatch_raises(): def test_dtype_mismatch_raises():

View File

@@ -46,6 +46,7 @@ from lerobot.datasets.utils import (
from lerobot.envs.factory import make_env_config from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config from lerobot.robots import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel from tests.utils import require_x86_64_kernel
@@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
# Instantiate both ways # Instantiate both ways
robot = make_robot_from_config(MockRobotConfig()) robot = make_robot_from_config(MockRobotConfig())
action_features = hw_to_dataset_features(robot.action_features, "action", True) action_features = hw_to_dataset_features(robot.action_features, "action", True)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
dataset_features = {**action_features, **obs_features} dataset_features = {**action_features, **obs_features}
root_create = tmp_path / "create" root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create( dataset_create = LeRobotDataset.create(
@@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name):
("frame_index", 0, True), ("frame_index", 0, True),
("timestamp", 0, True), ("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos? # TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True), (OBS_STATE, 1, True),
("next.reward", 0, False), ("next.reward", 0, False),
("next.done", 0, False), ("next.done", 0, False),
] ]
@@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
"""Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.""" """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata."""
features = { features = {
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
@@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
def test_update_chunk_settings_video_dataset(tmp_path): def test_update_chunk_settings_video_dataset(tmp_path):
"""Test update_chunk_settings with a video dataset to ensure video-specific logic works.""" """Test update_chunk_settings with a video dataset to ensure video-specific logic works."""
features = { features = {
"observation.images.cam": { f"{OBS_IMAGES}.cam": {
"dtype": "video", "dtype": "video",
"shape": (480, 640, 3), "shape": (480, 640, 3),
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],

View File

@@ -19,6 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE
from tests.utils import require_package from tests.utils import require_package
@@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params():
config = RewardClassifierConfig() config = RewardClassifierConfig()
config.input_features = { config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
} }
config.output_features = { config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
@@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params():
batch_size = 10 batch_size = 10
input = { input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)), OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
} }
@@ -83,7 +84,7 @@ def test_multiclass_classifier():
num_classes = 5 num_classes = 5
config = RewardClassifierConfig() config = RewardClassifierConfig()
config.input_features = { config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
} }
config.output_features = { config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
@@ -95,7 +96,7 @@ def test_multiclass_classifier():
batch_size = 10 batch_size = 10
input = { input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)), OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)), "next.reward": torch.rand((batch_size, num_classes)),
} }

View File

@@ -41,7 +41,7 @@ from lerobot.policies.factory import (
make_pre_post_processors, make_pre_post_processors,
) )
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
# Create only one camera input which is squared to fit all current policy constraints # Create only one camera input which is squared to fit all current policy constraints
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
camera_features = { camera_features = {
"observation.images.laptop": { f"{OBS_IMAGES}.laptop": {
"shape": (84, 84, 3), "shape": (84, 84, 3),
"names": ["height", "width", "channels"], "names": ["height", "width", "channels"],
"info": None, "info": None,
@@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"shape": (6,), "shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
}, },
"observation.state": { OBS_STATE: {
"dtype": "float32", "dtype": "float32",
"shape": (6,), "shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
@@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool):
preventing erroneous creation of the policy object. preventing erroneous creation of the policy object.
""" """
input_features = { input_features = {
"observation.state": PolicyFeature( OBS_STATE: PolicyFeature(
type=FeatureType.STATE, type=FeatureType.STATE,
shape=(10,), shape=(10,),
), ),
@@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool):
"""Simulates the complete state/action is constructed from more granular multiple """Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action""" keys, of the same type as the overall state/action"""
input_features = {} input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {} output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))

View File

@@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig, PolicyConfig,
SACConfig, SACConfig,
) )
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization(): def test_sac_config_default_initialization():
@@ -37,11 +38,11 @@ def test_sac_config_default_initialization():
"ACTION": NormalizationMode.MIN_MAX, "ACTION": NormalizationMode.MIN_MAX,
} }
assert config.dataset_stats == { assert config.dataset_stats == {
"observation.image": { OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406], "mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225], "std": [0.229, 0.224, 0.225],
}, },
"observation.state": { OBS_STATE: {
"min": [0.0, 0.0], "min": [0.0, 0.0],
"max": [1.0, 1.0], "max": [1.0, 1.0],
}, },
@@ -90,11 +91,11 @@ def test_sac_config_default_initialization():
# Dataset stats defaults # Dataset stats defaults
expected_dataset_stats = { expected_dataset_stats = {
"observation.image": { OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406], "mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225], "std": [0.229, 0.224, 0.225],
}, },
"observation.state": { OBS_STATE: {
"min": [0.0, 0.0], "min": [0.0, 0.0],
"max": [1.0, 1.0], "max": [1.0, 1.0],
}, },
@@ -191,7 +192,7 @@ def test_sac_config_custom_initialization():
def test_validate_features(): def test_validate_features():
config = SACConfig( config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
) )
config.validate_features() config.validate_features()
@@ -210,7 +211,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action(): def test_validate_features_missing_action():
config = SACConfig( config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
) )
with pytest.raises(ValueError, match="You must provide 'action' in the output features"): with pytest.raises(ValueError, match="You must provide 'action' in the output features"):

View File

@@ -23,6 +23,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed from lerobot.utils.random_utils import seeded_context, set_seed
try: try:
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return { return {
"observation.state": torch.randn(batch_size, state_dim), OBS_STATE: torch.randn(batch_size, state_dim),
} }
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return { return {
"observation.image": torch.randn(batch_size, 3, 84, 84), OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim), OBS_STATE: torch.randn(batch_size, state_dim),
} }
@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return { return {
"observation.state": torch.randn(batch_size, state_dim), OBS_STATE: torch.randn(batch_size, state_dim),
} }
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return { return {
"observation.state": torch.randn(batch_size, state_dim), OBS_STATE: torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84), OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
} }
@@ -180,10 +181,10 @@ def create_default_config(
action_dim += 1 action_dim += 1
config = SACConfig( config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={ dataset_stats={
"observation.state": { OBS_STATE: {
"min": [0.0] * state_dim, "min": [0.0] * state_dim,
"max": [1.0] * state_dim, "max": [1.0] * state_dim,
}, },
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
continuous_action_dim=continuous_action_dim, continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action, has_discrete_action=has_discrete_action,
) )
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = { config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1), "mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1),
} }

View File

@@ -342,7 +342,7 @@ def test_act_processor_batch_consistency():
batch = transition_to_batch(transition) batch = transition_to_batch(transition)
processed = preprocessor(batch) processed = preprocessor(batch)
assert processed["observation.state"].shape[0] == 1 # Batched assert processed[OBS_STATE].shape[0] == 1 # Batched
# Test already batched data # Test already batched data
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8

View File

@@ -2,14 +2,15 @@ import torch
from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor import DataProcessorPipeline, TransitionKey
from lerobot.processor.converters import batch_to_transition, transition_to_batch from lerobot.processor.converters import batch_to_transition, transition_to_batch
from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE
def _dummy_batch(): def _dummy_batch():
"""Create a dummy batch using the new format with observation.* and next.* keys.""" """Create a dummy batch using the new format with observation.* and next.* keys."""
return { return {
"observation.image.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
"observation.image.right": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
"action": torch.tensor([[0.5]]), "action": torch.tensor([[0.5]]),
"next.reward": 1.0, "next.reward": 1.0,
"next.done": False, "next.done": False,
@@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip():
batch_out = proc(batch_in) batch_out = proc(batch_in)
# Check that all observation.* keys are preserved # Check that all observation.* keys are preserved
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)}
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)}
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
# Check tensor values # Check tensor values
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"])
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"])
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])
# Check other fields # Check other fields
assert torch.allclose(batch_out["action"], batch_in["action"]) assert torch.allclose(batch_out["action"], batch_in["action"])
@@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip():
def test_batch_to_transition_observation_grouping(): def test_batch_to_transition_observation_grouping():
"""Test that batch_to_transition correctly groups observation.* keys.""" """Test that batch_to_transition correctly groups observation.* keys."""
batch = { batch = {
"observation.image.top": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4], OBS_STATE: [1, 2, 3, 4],
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
"next.reward": 1.5, "next.reward": 1.5,
"next.done": True, "next.done": True,
@@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping():
# Check observation is a dict with all observation.* keys # Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionKey.OBSERVATION], dict) assert isinstance(transition[TransitionKey.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionKey.OBSERVATION] assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION]
assert "observation.image.left" in transition[TransitionKey.OBSERVATION] assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION]
assert "observation.state" in transition[TransitionKey.OBSERVATION] assert OBS_STATE in transition[TransitionKey.OBSERVATION]
# Check values are preserved # Check values are preserved
assert torch.allclose( assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"]
) )
assert torch.allclose( assert torch.allclose(
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"]
) )
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4]
# Check other fields # Check other fields
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
@@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping():
def test_transition_to_batch_observation_flattening(): def test_transition_to_batch_observation_flattening():
"""Test that transition_to_batch correctly flattens observation dict.""" """Test that transition_to_batch correctly flattens observation dict."""
observation_dict = { observation_dict = {
"observation.image.top": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
"observation.image.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
"observation.state": [1, 2, 3, 4], OBS_STATE: [1, 2, 3, 4],
} }
transition = { transition = {
@@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening():
batch = transition_to_batch(transition) batch = transition_to_batch(transition)
# Check that observation.* keys are flattened back to batch # Check that observation.* keys are flattened back to batch
assert "observation.image.top" in batch assert f"{OBS_IMAGE}.top" in batch
assert "observation.image.left" in batch assert f"{OBS_IMAGE}.left" in batch
assert "observation.state" in batch assert OBS_STATE in batch
# Check values are preserved # Check values are preserved
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"])
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"])
assert batch["observation.state"] == [1, 2, 3, 4] assert batch[OBS_STATE] == [1, 2, 3, 4]
# Check other fields are mapped to next.* format # Check other fields are mapped to next.* format
assert batch["action"] == "action_data" assert batch["action"] == "action_data"
@@ -153,12 +154,12 @@ def test_no_observation_keys():
def test_minimal_batch(): def test_minimal_batch():
"""Test with minimal batch containing only observation.* and action.""" """Test with minimal batch containing only observation.* and action."""
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}
transition = batch_to_transition(batch) transition = batch_to_transition(batch)
# Check observation # Check observation
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"}
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
# Check defaults # Check defaults
@@ -170,7 +171,7 @@ def test_minimal_batch():
# Round trip # Round trip
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
assert reconstructed_batch["observation.state"] == "minimal_state" assert reconstructed_batch[OBS_STATE] == "minimal_state"
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
assert reconstructed_batch["next.reward"] == 0.0 assert reconstructed_batch["next.reward"] == 0.0
assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.done"]
@@ -205,9 +206,9 @@ def test_empty_batch():
def test_complex_nested_observation(): def test_complex_nested_observation():
"""Test with complex nested observation data.""" """Test with complex nested observation data."""
batch = { batch = {
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
"observation.state": torch.randn(7), OBS_STATE: torch.randn(7),
"action": torch.randn(8), "action": torch.randn(8),
"next.reward": 3.14, "next.reward": 3.14,
"next.done": False, "next.done": False,
@@ -219,20 +220,20 @@ def test_complex_nested_observation():
reconstructed_batch = transition_to_batch(transition) reconstructed_batch = transition_to_batch(transition)
# Check that all observation keys are preserved # Check that all observation keys are preserved
original_obs_keys = {k for k in batch if k.startswith("observation.")} original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)}
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)}
assert original_obs_keys == reconstructed_obs_keys assert original_obs_keys == reconstructed_obs_keys
# Check tensor values # Check tensor values
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE])
# Check nested dict with tensors # Check nested dict with tensors
assert torch.allclose( assert torch.allclose(
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"]
) )
assert torch.allclose( assert torch.allclose(
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"]
) )
# Check action tensor # Check action tensor
@@ -264,7 +265,7 @@ def test_custom_converter():
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
batch = { batch = {
"observation.state": torch.randn(1, 4), OBS_STATE: torch.randn(1, 4),
"action": torch.randn(1, 2), "action": torch.randn(1, 2),
"next.reward": 1.0, "next.reward": 1.0,
"next.done": False, "next.done": False,
@@ -274,5 +275,5 @@ def test_custom_converter():
# Check the reward was doubled by our custom converter # Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0 assert result["next.reward"] == 2.0
assert torch.allclose(result["observation.state"], batch["observation.state"]) assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
assert torch.allclose(result["action"], batch["action"]) assert torch.allclose(result["action"], batch["action"])

View File

@@ -9,6 +9,7 @@ from lerobot.processor.converters import (
to_tensor, to_tensor,
transition_to_batch, transition_to_batch,
) )
from lerobot.utils.constants import OBS_STATE, OBS_STR
# Tests for the unified to_tensor function # Tests for the unified to_tensor function
@@ -118,16 +119,16 @@ def test_to_tensor_dictionaries():
# Nested dictionary # Nested dictionary
nested = { nested = {
"action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
"observation": {"mean": np.array([0.5, 0.6]), "count": 10}, OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
} }
result = to_tensor(nested) result = to_tensor(nested)
assert isinstance(result, dict) assert isinstance(result, dict)
assert isinstance(result["action"], dict) assert isinstance(result["action"], dict)
assert isinstance(result["observation"], dict) assert isinstance(result[OBS_STR], dict)
assert isinstance(result["action"]["mean"], torch.Tensor) assert isinstance(result["action"]["mean"], torch.Tensor)
assert isinstance(result["observation"]["mean"], torch.Tensor) assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))
def test_to_tensor_none_filtering(): def test_to_tensor_none_filtering():
@@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields():
# Create batch with index and task_index fields # Create batch with index and task_index fields
batch = { batch = {
"observation.state": torch.randn(1, 7), OBS_STATE: torch.randn(1, 7),
"action": torch.randn(1, 4), "action": torch.randn(1, 4),
"next.reward": 1.5, "next.reward": 1.5,
"next.done": False, "next.done": False,
@@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields():
# Create transition with index and task_index in complementary_data # Create transition with index and task_index in complementary_data
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(1, 7)}, observation={OBS_STATE: torch.randn(1, 7)},
action=torch.randn(1, 4), action=torch.randn(1, 4),
reward=1.5, reward=1.5,
done=False, done=False,
@@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields():
# Batch without index/task_index # Batch without index/task_index
batch = { batch = {
"observation.state": torch.randn(1, 7), OBS_STATE: torch.randn(1, 7),
"action": torch.randn(1, 4), "action": torch.randn(1, 4),
"task": ["pick_cube"], "task": ["pick_cube"],
} }
@@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields():
# Transition without index/task_index # Transition without index/task_index
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(1, 7)}, observation={OBS_STATE: torch.randn(1, 7)},
action=torch.randn(1, 4), action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]}, complementary_data={"task": ["navigate"]},
) )

View File

@@ -21,6 +21,7 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def test_basic_functionality(): def test_basic_functionality():
@@ -28,7 +29,7 @@ def test_basic_functionality():
processor = DeviceProcessorStep(device="cpu") processor = DeviceProcessorStep(device="cpu")
# Create a transition with CPU tensors # Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
action = torch.randn(5) action = torch.randn(5)
reward = torch.tensor(1.0) reward = torch.tensor(1.0)
done = torch.tensor(False) done = torch.tensor(False)
@@ -41,8 +42,8 @@ def test_basic_functionality():
result = processor(transition) result = processor(transition)
# Check that all tensors are on CPU # Check that all tensors are on CPU
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu"
assert result[TransitionKey.ACTION].device.type == "cpu" assert result[TransitionKey.ACTION].device.type == "cpu"
assert result[TransitionKey.REWARD].device.type == "cpu" assert result[TransitionKey.REWARD].device.type == "cpu"
assert result[TransitionKey.DONE].device.type == "cpu" assert result[TransitionKey.DONE].device.type == "cpu"
@@ -55,7 +56,7 @@ def test_cuda_functionality():
processor = DeviceProcessorStep(device="cuda") processor = DeviceProcessorStep(device="cuda")
# Create a transition with CPU tensors # Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
action = torch.randn(5) action = torch.randn(5)
reward = torch.tensor(1.0) reward = torch.tensor(1.0)
done = torch.tensor(False) done = torch.tensor(False)
@@ -68,8 +69,8 @@ def test_cuda_functionality():
result = processor(transition) result = processor(transition)
# Check that all tensors are on CUDA # Check that all tensors are on CUDA
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda"
assert result[TransitionKey.DONE].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda"
@@ -81,14 +82,14 @@ def test_specific_cuda_device():
"""Test device processor with specific CUDA device.""" """Test device processor with specific CUDA device."""
processor = DeviceProcessorStep(device="cuda:0") processor = DeviceProcessorStep(device="cuda:0")
observation = {"observation.state": torch.randn(10)} observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5) action = torch.randn(5)
transition = create_transition(observation=observation, action=action) transition = create_transition(observation=observation, action=action)
result = processor(transition) result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0
assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.ACTION].device.index == 0 assert result[TransitionKey.ACTION].device.index == 0
@@ -98,7 +99,7 @@ def test_non_tensor_values():
processor = DeviceProcessorStep(device="cpu") processor = DeviceProcessorStep(device="cpu")
observation = { observation = {
"observation.state": torch.randn(10), OBS_STATE: torch.randn(10),
"observation.metadata": {"key": "value"}, # Non-tensor data "observation.metadata": {"key": "value"}, # Non-tensor data
"observation.list": [1, 2, 3], # Non-tensor data "observation.list": [1, 2, 3], # Non-tensor data
} }
@@ -110,7 +111,7 @@ def test_non_tensor_values():
result = processor(transition) result = processor(transition)
# Check tensors are processed # Check tensors are processed
assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor)
assert isinstance(result[TransitionKey.ACTION], torch.Tensor) assert isinstance(result[TransitionKey.ACTION], torch.Tensor)
# Check non-tensor values are preserved # Check non-tensor values are preserved
@@ -130,9 +131,9 @@ def test_none_values():
assert result[TransitionKey.ACTION].device.type == "cpu" assert result[TransitionKey.ACTION].device.type == "cpu"
# Test with None action # Test with None action
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
result = processor(transition) result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
assert result[TransitionKey.ACTION] is None assert result[TransitionKey.ACTION] is None
@@ -271,9 +272,7 @@ def test_features():
processor = DeviceProcessorStep(device="cpu") processor = DeviceProcessorStep(device="cpu")
features = { features = {
PipelineFeatureType.OBSERVATION: { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
},
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
} }
@@ -376,7 +375,7 @@ def test_reward_done_truncated_types():
# Test with scalar values (not tensors) # Test with scalar values (not tensors)
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(5)}, observation={OBS_STATE: torch.randn(5)},
action=torch.randn(3), action=torch.randn(3),
reward=1.0, # float reward=1.0, # float
done=False, # bool done=False, # bool
@@ -392,7 +391,7 @@ def test_reward_done_truncated_types():
# Test with tensor values # Test with tensor values
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(5)}, observation={OBS_STATE: torch.randn(5)},
action=torch.randn(3), action=torch.randn(3),
reward=torch.tensor(1.0), reward=torch.tensor(1.0),
done=torch.tensor(False), done=torch.tensor(False),
@@ -422,7 +421,7 @@ def test_complementary_data_preserved():
} }
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data
) )
result = processor(transition) result = processor(transition)
@@ -491,13 +490,13 @@ def test_float_dtype_bfloat16():
"""Test conversion to bfloat16.""" """Test conversion to bfloat16."""
processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16")
observation = {"observation.state": torch.randn(5, dtype=torch.float32)} observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)}
action = torch.randn(3, dtype=torch.float64) action = torch.randn(3, dtype=torch.float64)
transition = create_transition(observation=observation, action=action) transition = create_transition(observation=observation, action=action)
result = processor(transition) result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert result[TransitionKey.ACTION].dtype == torch.bfloat16 assert result[TransitionKey.ACTION].dtype == torch.bfloat16
@@ -505,13 +504,13 @@ def test_float_dtype_float64():
"""Test conversion to float64.""" """Test conversion to float64."""
processor = DeviceProcessorStep(device="cpu", float_dtype="float64") processor = DeviceProcessorStep(device="cpu", float_dtype="float64")
observation = {"observation.state": torch.randn(5, dtype=torch.float16)} observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)}
action = torch.randn(3, dtype=torch.float32) action = torch.randn(3, dtype=torch.float32)
transition = create_transition(observation=observation, action=action) transition = create_transition(observation=observation, action=action)
result = processor(transition) result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64
assert result[TransitionKey.ACTION].dtype == torch.float64 assert result[TransitionKey.ACTION].dtype == torch.float64
@@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors():
processor = DeviceProcessorStep(device="cpu", float_dtype="float32") processor = DeviceProcessorStep(device="cpu", float_dtype="float32")
observation = { observation = {
"observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
"observation.state": torch.randn(10, dtype=torch.float64), # Should convert OBS_STATE: torch.randn(10, dtype=torch.float64), # Should convert
"observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert
"observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert
} }
@@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors():
result = processor(transition) result = processor(transition)
# Check conversions # Check conversions
assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8 # Unchanged
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted
assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged
assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged
assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted
@@ -612,7 +611,7 @@ def test_complementary_data_index_fields():
"episode_id": 123, # Non-tensor field "episode_id": 123, # Non-tensor field
} }
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(1, 7)}, observation={OBS_STATE: torch.randn(1, 7)},
action=torch.randn(1, 4), action=torch.randn(1, 4),
complementary_data=complementary_data, complementary_data=complementary_data,
) )
@@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda():
processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")
# Create full transition with mixed CPU tensors # Create full transition with mixed CPU tensors
observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)}
action = torch.randn(1, 4, dtype=torch.float32) action = torch.randn(1, 4, dtype=torch.float32)
reward = torch.tensor(1.5, dtype=torch.float32) reward = torch.tensor(1.5, dtype=torch.float32)
done = torch.tensor(False) done = torch.tensor(False)
@@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda():
result = processor(transition) result = processor(transition)
# Check all components moved to CUDA # Check all components moved to CUDA
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda"
assert result[TransitionKey.DONE].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda"
@@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda():
assert processed_comp_data["task_index"].device.type == "cuda" assert processed_comp_data["task_index"].device.type == "cuda"
# Check float conversion happened for float tensors # Check float conversion happened for float tensors
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
assert result[TransitionKey.ACTION].dtype == torch.float16 assert result[TransitionKey.ACTION].dtype == torch.float16
assert result[TransitionKey.REWARD].dtype == torch.float16 assert result[TransitionKey.REWARD].dtype == torch.float16
@@ -782,7 +781,7 @@ def test_complementary_data_empty():
processor = DeviceProcessorStep(device="cpu") processor = DeviceProcessorStep(device="cpu")
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(1, 7)}, observation={OBS_STATE: torch.randn(1, 7)},
complementary_data={}, complementary_data={},
) )
@@ -797,7 +796,7 @@ def test_complementary_data_none():
processor = DeviceProcessorStep(device="cpu") processor = DeviceProcessorStep(device="cpu")
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(1, 7)}, observation={OBS_STATE: torch.randn(1, 7)},
complementary_data=None, complementary_data=None,
) )
@@ -814,8 +813,8 @@ def test_preserves_gpu_placement():
# Create tensors already on GPU # Create tensors already on GPU
observation = { observation = {
"observation.state": torch.randn(10).cuda(), # Already on GPU OBS_STATE: torch.randn(10).cuda(), # Already on GPU
"observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU OBS_IMAGE: torch.randn(3, 224, 224).cuda(), # Already on GPU
} }
action = torch.randn(5).cuda() # Already on GPU action = torch.randn(5).cuda() # Already on GPU
@@ -823,14 +822,12 @@ def test_preserves_gpu_placement():
result = processor(transition) result = processor(transition)
# Check that tensors remain on their original GPU # Check that tensors remain on their original GPU
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda"
# Verify no unnecessary copies were made (same data pointer) # Verify no unnecessary copies were made (same data pointer)
assert torch.equal( assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE])
result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"]
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
@@ -842,8 +839,8 @@ def test_multi_gpu_preservation():
# Create tensors on cuda:1 (simulating Accelerate placement) # Create tensors on cuda:1 (simulating Accelerate placement)
cuda1_device = torch.device("cuda:1") cuda1_device = torch.device("cuda:1")
observation = { observation = {
"observation.state": torch.randn(10).to(cuda1_device), OBS_STATE: torch.randn(10).to(cuda1_device),
"observation.image": torch.randn(3, 224, 224).to(cuda1_device), OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device),
} }
action = torch.randn(5).to(cuda1_device) action = torch.randn(5).to(cuda1_device)
@@ -851,20 +848,20 @@ def test_multi_gpu_preservation():
result = processor_gpu(transition) result = processor_gpu(transition)
# Check that tensors remain on cuda:1 (not moved to cuda:0) # Check that tensors remain on cuda:1 (not moved to cuda:0)
assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device
assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device
assert result[TransitionKey.ACTION].device == cuda1_device assert result[TransitionKey.ACTION].device == cuda1_device
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU) # Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
processor_cpu = DeviceProcessorStep(device="cpu") processor_cpu = DeviceProcessorStep(device="cpu")
transition_gpu = create_transition( transition_gpu = create_transition(
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda()
) )
result_cpu = processor_cpu(transition_gpu) result_cpu = processor_cpu(transition_gpu)
# Check that tensors are moved to CPU # Check that tensors are moved to CPU
assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
assert result_cpu[TransitionKey.ACTION].device.type == "cpu" assert result_cpu[TransitionKey.ACTION].device.type == "cpu"
@@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario():
# Simulate data already placed by Accelerate # Simulate data already placed by Accelerate
device = torch.device(f"cuda:{gpu_id}") device = torch.device(f"cuda:{gpu_id}")
observation = {"observation.state": torch.randn(1, 10).to(device)} observation = {OBS_STATE: torch.randn(1, 10).to(device)}
action = torch.randn(1, 5).to(device) action = torch.randn(1, 5).to(device)
transition = create_transition(observation=observation, action=action) transition = create_transition(observation=observation, action=action)
result = processor(transition) result = processor(transition)
# Verify data stays on the GPU where Accelerate placed it # Verify data stays on the GPU where Accelerate placed it
assert result[TransitionKey.OBSERVATION]["observation.state"].device == device assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device
assert result[TransitionKey.ACTION].device == device assert result[TransitionKey.ACTION].device == device
@@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data():
} }
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(5, dtype=torch.float64)}, observation={OBS_STATE: torch.randn(5, dtype=torch.float64)},
action=torch.randn(3, dtype=torch.float64), action=torch.randn(3, dtype=torch.float64),
complementary_data=complementary_data, complementary_data=complementary_data,
) )
@@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data():
result = processor(transition) result = processor(transition)
# Check that all tensors are on MPS device # Check that all tensors are on MPS device
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps"
assert result[TransitionKey.ACTION].device.type == "mps" assert result[TransitionKey.ACTION].device.type == "mps"
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
@@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data():
assert processed_comp_data["float32_tensor"].device.type == "mps" assert processed_comp_data["float32_tensor"].device.type == "mps"
# Check dtype conversions # Check dtype conversions
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted
assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted
assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted
assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged

View File

@@ -25,6 +25,7 @@ from pathlib import Path
import pytest import pytest
from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError
from lerobot.utils.constants import OBS_STATE
def test_is_processor_config_valid_configs(): def test_is_processor_config_valid_configs():
@@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only():
# Create a model config (like old LeRobot format) # Create a model config (like old LeRobot format)
model_config = { model_config = {
"type": "act", "type": "act",
"input_features": {"observation.state": {"shape": [7]}}, "input_features": {OBS_STATE: {"shape": [7]}},
"output_features": {"action": {"shape": [7]}}, "output_features": {"action": {"shape": [7]}},
"hidden_dim": 256, "hidden_dim": 256,
"n_obs_steps": 1, "n_obs_steps": 1,

File diff suppressed because it is too large Load Diff

View File

@@ -39,8 +39,8 @@ def test_process_single_image():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly # Check that the image was processed correctly
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
processed_img = processed_obs["observation.image"] processed_img = processed_obs[OBS_IMAGE]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64) assert processed_img.shape == (1, 3, 64, 64)
@@ -66,12 +66,12 @@ def test_process_image_dict():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed # Check that both images were processed
assert "observation.images.camera1" in processed_obs assert f"{OBS_IMAGES}.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs assert f"{OBS_IMAGES}.camera2" in processed_obs
# Check shapes # Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image(): def test_process_batched_image():
@@ -88,7 +88,7 @@ def test_process_batched_image():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved # Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64) assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64)
def test_invalid_image_format(): def test_invalid_image_format():
@@ -173,10 +173,10 @@ def test_process_environment_state():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed # Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs assert OBS_ENV_STATE in processed_obs
assert "environment_state" not in processed_obs assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"] processed_state = processed_obs[OBS_ENV_STATE]
assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32 assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
@@ -194,10 +194,10 @@ def test_process_agent_pos():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed # Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
assert "agent_pos" not in processed_obs assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"] processed_state = processed_obs[OBS_STATE]
assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32 assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
@@ -217,8 +217,8 @@ def test_process_batched_states():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved # Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2) assert processed_obs[OBS_ENV_STATE].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2) assert processed_obs[OBS_STATE].shape == (2, 2)
def test_process_both_states(): def test_process_both_states():
@@ -235,8 +235,8 @@ def test_process_both_states():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed # Check that both states were processed
assert "observation.environment_state" in processed_obs assert OBS_ENV_STATE in processed_obs
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
# Check that original keys were removed # Check that original keys were removed
assert "environment_state" not in processed_obs assert "environment_state" not in processed_obs
@@ -281,12 +281,12 @@ def test_complete_observation_processing():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed # Check that image was processed
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32) assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32)
# Check that states were processed # Check that states were processed
assert "observation.environment_state" in processed_obs assert OBS_ENV_STATE in processed_obs
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
# Check that original keys were removed # Check that original keys were removed
assert "pixels" not in processed_obs assert "pixels" not in processed_obs
@@ -308,7 +308,7 @@ def test_image_only_processing():
result = processor(transition) result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
assert len(processed_obs) == 1 assert len(processed_obs) == 1
@@ -323,7 +323,7 @@ def test_state_only_processing():
result = processor(transition) result = processor(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
assert "agent_pos" not in processed_obs assert "agent_pos" not in processed_obs
@@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
proc = VanillaObservationProcessorStep() proc = VanillaObservationProcessorStep()
features = { features = {
PipelineFeatureType.OBSERVATION: { PipelineFeatureType.OBSERVATION: {
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)),
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
}, },
} }
@@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
assert ( assert (
OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
== features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] == features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
) )
assert ( assert (
OBS_STATE in out[PipelineFeatureType.OBSERVATION] OBS_STATE in out[PipelineFeatureType.OBSERVATION]

View File

@@ -35,6 +35,7 @@ from lerobot.processor import (
TransitionKey, TransitionKey,
) )
from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
from tests.conftest import assert_contract_is_typed from tests.conftest import assert_contract_is_typed
@@ -255,7 +256,7 @@ def test_step_through_with_dict():
pipeline = DataProcessorPipeline([step1, step2]) pipeline = DataProcessorPipeline([step1, step2])
batch = { batch = {
"observation.image": None, OBS_IMAGE: None,
"action": None, "action": None,
"next.reward": 0.0, "next.reward": 0.0,
"next.done": False, "next.done": False,
@@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions():
# Verify it uses default converters by checking with standard batch format # Verify it uses default converters by checking with standard batch format
batch = { batch = {
"observation.image": torch.randn(1, 3, 32, 32), OBS_IMAGE: torch.randn(1, 3, 32, 32),
"action": torch.randn(1, 7), "action": torch.randn(1, 7),
"next.reward": torch.tensor([1.0]), "next.reward": torch.tensor([1.0]),
"next.done": torch.tensor([False]), "next.done": torch.tensor([False]),
@@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions():
# Should work with standard format (wouldn't work with custom converter) # Should work with standard format (wouldn't work with custom converter)
result = loaded(batch) result = loaded(batch)
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
assert "observation.image" in result assert OBS_IMAGE in result
class NonCompliantStep: class NonCompliantStep:
@@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep):
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
# State features (mix EE and a joint state) # State features (mix EE and a joint state)
features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float
features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float
if self.add_front_image: if self.add_front_image:
features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape
return features return features
@@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only():
) )
# Expect only "action" with joint names # Expect only "action" with joint names
assert "action" in out and "observation.state" not in out assert "action" in out and OBS_STATE not in out
assert out["action"]["dtype"] == "float32" assert out["action"]["dtype"] == "float32"
assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
assert out["action"]["shape"] == (len(out["action"]["names"]),) assert out["action"]["shape"] == (len(out["action"]["names"]),)
@@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
pipeline=rp, pipeline=rp,
initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
use_videos=True, use_videos=True,
patterns=["action.ee", "observation.state"], patterns=["action.ee", OBS_STATE],
) )
# Action should pack only EE names # Action should pack only EE names
@@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos():
assert out["action"]["dtype"] == "float32" assert out["action"]["dtype"] == "float32"
# Observation state should pack both ee.x and j1.pos as a vector # Observation state should pack both ee.x and j1.pos as a vector
assert "observation.state" in out assert OBS_STATE in out
assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
assert out["observation.state"]["dtype"] == "float32" assert out[OBS_STATE]["dtype"] == "float32"
# Cameras from initial_features appear as videos # Cameras from initial_features appear as videos
for cam in ("front", "side"): for cam in ("front", "side"):
key = f"observation.images.{cam}" key = f"{OBS_IMAGES}.{cam}"
assert key in out assert key in out
assert out[key]["dtype"] == "video" assert out[key]["dtype"] == "video"
assert out[key]["shape"] == initial[cam] assert out[key]["shape"] == initial[cam]
@@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false():
patterns=None, patterns=None,
) )
key = "observation.images.back" key = f"{OBS_IMAGES}.back"
key_front = "observation.images.front" key_front = f"{OBS_IMAGES}.front"
assert key not in out assert key not in out
assert key_front not in out assert key_front not in out
@@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true():
patterns=None, patterns=None,
) )
key = "observation.images.front" key = f"{OBS_IMAGES}.front"
key_back = "observation.images.back" key_back = f"{OBS_IMAGES}.back"
assert key in out assert key in out
assert key_back in out assert key_back in out
assert out[key]["dtype"] == "video" assert out[key]["dtype"] == "video"
@@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image():
pipeline=rp, pipeline=rp,
initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
use_videos=True, use_videos=True,
patterns=["observation.images.front"], patterns=[f"{OBS_IMAGES}.front"],
) )
key = "observation.images.front" key = f"{OBS_IMAGES}.front"
assert key in out assert key in out
assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial

View File

@@ -28,6 +28,7 @@ from lerobot.processor import (
) )
from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.converters import create_transition, identity_transition
from lerobot.processor.rename_processor import rename_stats from lerobot.processor.rename_processor import rename_stats
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
from tests.conftest import assert_contract_is_typed from tests.conftest import assert_contract_is_typed
@@ -121,13 +122,13 @@ def test_overlapping_rename():
def test_partial_rename(): def test_partial_rename():
"""Test renaming only some keys.""" """Test renaming only some keys."""
rename_map = { rename_map = {
"observation.state": "observation.proprio_state", OBS_STATE: "observation.proprio_state",
"pixels": "observation.image", "pixels": OBS_IMAGE,
} }
processor = RenameObservationsProcessorStep(rename_map=rename_map) processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = { observation = {
"observation.state": torch.randn(10), OBS_STATE: torch.randn(10),
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
"reward": 1.0, "reward": 1.0,
"info": {"episode": 1}, "info": {"episode": 1},
@@ -139,8 +140,8 @@ def test_partial_rename():
# Check renamed keys # Check renamed keys
assert "observation.proprio_state" in processed_obs assert "observation.proprio_state" in processed_obs
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
assert "observation.state" not in processed_obs assert OBS_STATE not in processed_obs
assert "pixels" not in processed_obs assert "pixels" not in processed_obs
# Check unchanged keys # Check unchanged keys
@@ -174,8 +175,8 @@ def test_state_dict():
def test_integration_with_robot_processor(): def test_integration_with_robot_processor():
"""Test integration with RobotProcessor pipeline.""" """Test integration with RobotProcessor pipeline."""
rename_map = { rename_map = {
"agent_pos": "observation.state", "agent_pos": OBS_STATE,
"pixels": "observation.image", "pixels": OBS_IMAGE,
} }
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
@@ -196,8 +197,8 @@ def test_integration_with_robot_processor():
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline # Check renaming worked through pipeline
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
assert "agent_pos" not in processed_obs assert "agent_pos" not in processed_obs
assert "pixels" not in processed_obs assert "pixels" not in processed_obs
assert processed_obs["other_data"] == "preserve_me" assert processed_obs["other_data"] == "preserve_me"
@@ -210,8 +211,8 @@ def test_integration_with_robot_processor():
def test_save_and_load_pretrained(): def test_save_and_load_pretrained():
"""Test saving and loading processor with RobotProcessor.""" """Test saving and loading processor with RobotProcessor."""
rename_map = { rename_map = {
"old_state": "observation.state", "old_state": OBS_STATE,
"old_image": "observation.image", "old_image": OBS_IMAGE,
} }
processor = RenameObservationsProcessorStep(rename_map=rename_map) processor = RenameObservationsProcessorStep(rename_map=rename_map)
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
@@ -253,10 +254,10 @@ def test_save_and_load_pretrained():
result = loaded_pipeline(transition) result = loaded_pipeline(transition)
processed_obs = result[TransitionKey.OBSERVATION] processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs assert OBS_STATE in processed_obs
assert "observation.image" in processed_obs assert OBS_IMAGE in processed_obs
assert processed_obs["observation.state"] == [1, 2, 3] assert processed_obs[OBS_STATE] == [1, 2, 3]
assert processed_obs["observation.image"] == "image_data" assert processed_obs[OBS_IMAGE] == "image_data"
def test_registry_functionality(): def test_registry_functionality():
@@ -317,8 +318,8 @@ def test_chained_rename_processors():
# Second processor: rename to final format # Second processor: rename to final format
processor2 = RenameObservationsProcessorStep( processor2 = RenameObservationsProcessorStep(
rename_map={ rename_map={
"agent_position": "observation.state", "agent_position": OBS_STATE,
"camera_image": "observation.image", "camera_image": OBS_IMAGE,
} }
) )
@@ -342,8 +343,8 @@ def test_chained_rename_processors():
# After second processor # After second processor
final_obs = results[2][TransitionKey.OBSERVATION] final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs assert OBS_STATE in final_obs
assert "observation.image" in final_obs assert OBS_IMAGE in final_obs
assert final_obs["extra"] == "keep_me" assert final_obs["extra"] == "keep_me"
# Original keys should be gone # Original keys should be gone
@@ -356,15 +357,15 @@ def test_chained_rename_processors():
def test_nested_observation_rename(): def test_nested_observation_rename():
"""Test renaming with nested observation structures.""" """Test renaming with nested observation structures."""
rename_map = { rename_map = {
"observation.images.left": "observation.camera.left_view", f"{OBS_IMAGES}.left": "observation.camera.left_view",
"observation.images.right": "observation.camera.right_view", f"{OBS_IMAGES}.right": "observation.camera.right_view",
"observation.proprio": "observation.proprioception", "observation.proprio": "observation.proprioception",
} }
processor = RenameObservationsProcessorStep(rename_map=rename_map) processor = RenameObservationsProcessorStep(rename_map=rename_map)
observation = { observation = {
"observation.images.left": torch.randn(3, 64, 64), f"{OBS_IMAGES}.left": torch.randn(3, 64, 64),
"observation.images.right": torch.randn(3, 64, 64), f"{OBS_IMAGES}.right": torch.randn(3, 64, 64),
"observation.proprio": torch.randn(7), "observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed "observation.gripper": torch.tensor([0.0]), # Not renamed
} }
@@ -382,8 +383,8 @@ def test_nested_observation_rename():
assert "observation.gripper" in processed_obs assert "observation.gripper" in processed_obs
# Check old keys removed # Check old keys removed
assert "observation.images.left" not in processed_obs assert f"{OBS_IMAGES}.left" not in processed_obs
assert "observation.images.right" not in processed_obs assert f"{OBS_IMAGES}.right" not in processed_obs
assert "observation.proprio" not in processed_obs assert "observation.proprio" not in processed_obs
@@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory):
# Chain two rename processors at the contract level # Chain two rename processors at the contract level
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
processor2 = RenameObservationsProcessorStep( processor2 = RenameObservationsProcessorStep(
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE}
) )
pipeline = DataProcessorPipeline([processor1, processor2]) pipeline = DataProcessorPipeline([processor1, processor2])
@@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory):
} }
out = pipeline.transform_features(initial_features=spec) out = pipeline.transform_features(initial_features=spec)
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"}
assert ( assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"]
out[PipelineFeatureType.OBSERVATION]["observation.state"] assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"]
== spec[PipelineFeatureType.OBSERVATION]["pos"]
)
assert (
out[PipelineFeatureType.OBSERVATION]["observation.image"]
== spec[PipelineFeatureType.OBSERVATION]["img"]
)
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
assert_contract_is_typed(out) assert_contract_is_typed(out)
def test_rename_stats_basic(): def test_rename_stats_basic():
orig = { orig = {
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
"action": {"mean": np.array([0.0])}, "action": {"mean": np.array([0.0])},
} }
mapping = {"observation.state": "observation.robot_state"} mapping = {OBS_STATE: "observation.robot_state"}
renamed = rename_stats(orig, mapping) renamed = rename_stats(orig, mapping)
assert "observation.robot_state" in renamed and "observation.state" not in renamed assert "observation.robot_state" in renamed and OBS_STATE not in renamed
# Ensure deep copy: mutate original and verify renamed unaffected # Ensure deep copy: mutate original and verify renamed unaffected
orig["observation.state"]["mean"][0] = 42.0 orig[OBS_STATE]["mean"][0] = 42.0
assert renamed["observation.robot_state"]["mean"][0] != 42.0 assert renamed["observation.robot_state"]["mean"][0] != 42.0

View File

@@ -11,7 +11,7 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import OBS_LANGUAGE from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
from tests.utils import require_package from tests.utils import require_package
@@ -503,16 +503,14 @@ def test_features_basic():
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
input_features = { input_features = {
PipelineFeatureType.OBSERVATION: { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
},
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
} }
output_features = processor.transform_features(input_features) output_features = processor.transform_features(input_features)
# Check that original features are preserved # Check that original features are preserved
assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION]
assert "action" in output_features[PipelineFeatureType.ACTION] assert "action" in output_features[PipelineFeatureType.ACTION]
# Check that tokenized features are added # Check that tokenized features are added
@@ -797,7 +795,7 @@ def test_device_detection_cpu():
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CPU tensors # Create transition with CPU tensors
observation = {"observation.state": torch.randn(10)} # CPU tensor observation = {OBS_STATE: torch.randn(10)} # CPU tensor
action = torch.randn(5) # CPU tensor action = torch.randn(5) # CPU tensor
transition = create_transition( transition = create_transition(
observation=observation, action=action, complementary_data={"task": "test task"} observation=observation, action=action, complementary_data={"task": "test task"}
@@ -821,7 +819,7 @@ def test_device_detection_cuda():
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CUDA tensors # Create transition with CUDA tensors
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
action = torch.randn(5).cuda() # CUDA tensor action = torch.randn(5).cuda() # CUDA tensor
transition = create_transition( transition = create_transition(
observation=observation, action=action, complementary_data={"task": "test task"} observation=observation, action=action, complementary_data={"task": "test task"}
@@ -847,7 +845,7 @@ def test_device_detection_multi_gpu():
# Test with tensors on cuda:1 # Test with tensors on cuda:1
device = torch.device("cuda:1") device = torch.device("cuda:1")
observation = {"observation.state": torch.randn(10).to(device)} observation = {OBS_STATE: torch.randn(10).to(device)}
action = torch.randn(5).to(device) action = torch.randn(5).to(device)
transition = create_transition( transition = create_transition(
observation=observation, action=action, complementary_data={"task": "multi gpu test"} observation=observation, action=action, complementary_data={"task": "multi gpu test"}
@@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype():
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with float tensor (to test dtype isn't affected) # Create transition with float tensor (to test dtype isn't affected)
observation = {"observation.state": torch.randn(10, dtype=torch.float16)} observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)}
transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) transition = create_transition(observation=observation, complementary_data={"task": "dtype test"})
result = processor(transition) result = processor(transition)
@@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
# Start with CPU tensors # Start with CPU tensors
transition = create_transition( transition = create_transition(
observation={"observation.state": torch.randn(10)}, # CPU observation={OBS_STATE: torch.randn(10)}, # CPU
action=torch.randn(5), # CPU action=torch.randn(5), # CPU
complementary_data={"task": "pipeline test"}, complementary_data={"task": "pipeline test"},
) )
@@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
result = robot_processor(transition) result = robot_processor(transition)
# All tensors should end up on CUDA (moved by DeviceProcessorStep) # All tensors should end up on CUDA (moved by DeviceProcessorStep)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda"
# Tokenized tensors should also be on CUDA # Tokenized tensors should also be on CUDA
@@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario():
# Simulate Accelerate scenario: batch already on GPU # Simulate Accelerate scenario: batch already on GPU
device = torch.device("cuda:0") device = torch.device("cuda:0")
observation = { observation = {
"observation.state": torch.randn(1, 10).to(device), # Batched, on GPU OBS_STATE: torch.randn(1, 10).to(device), # Batched, on GPU
"observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU
} }
action = torch.randn(1, 5).to(device) # Batched, on GPU action = torch.randn(1, 5).to(device) # Batched, on GPU

View File

@@ -21,6 +21,7 @@ import pytest
import torch import torch
from torch.multiprocessing import Event, Queue from torch.multiprocessing import Event, Queue
from lerobot.utils.constants import OBS_STR
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
from tests.utils import require_package from tests.utils import require_package
@@ -110,12 +111,12 @@ def test_push_transitions_to_transport_queue():
transitions = [] transitions = []
for i in range(3): for i in range(3):
transition = Transition( transition = Transition(
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
action=torch.randn(5), action=torch.randn(5),
reward=torch.tensor(1.0 + i), reward=torch.tensor(1.0 + i),
done=torch.tensor(False), done=torch.tensor(False),
truncated=torch.tensor(False), truncated=torch.tensor(False),
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
complementary_info={"step": torch.tensor(i)}, complementary_info={"step": torch.tensor(i)},
) )
transitions.append(transition) transitions.append(transition)

View File

@@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import OBS_STR
from lerobot.utils.transition import Transition from lerobot.utils.transition import Transition
from tests.utils import require_package from tests.utils import require_package
@@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]:
transitions = [] transitions = []
for i in range(count): for i in range(count):
transition = Transition( transition = Transition(
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
action=torch.randn(5), action=torch.randn(5),
reward=torch.tensor(1.0 + i), reward=torch.tensor(1.0 + i),
done=torch.tensor(i == count - 1), # Last transition is done done=torch.tensor(i == count - 1), # Last transition is done
truncated=torch.tensor(False), truncated=torch.tensor(False),
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
) )
transitions.append(transition) transitions.append(transition)

View File

@@ -22,11 +22,12 @@ import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
from tests.fixtures.constants import DUMMY_REPO_ID from tests.fixtures.constants import DUMMY_REPO_ID
def state_dims() -> list[str]: def state_dims() -> list[str]:
return ["observation.image", "observation.state"] return [OBS_IMAGE, OBS_STATE]
@pytest.fixture @pytest.fixture
@@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor:
def create_dummy_transition() -> dict: def create_dummy_transition() -> dict:
return { return {
"observation.image": create_random_image(), OBS_IMAGE: create_random_image(),
"action": torch.randn(4), "action": torch.randn(4),
"reward": torch.tensor(1.0), "reward": torch.tensor(1.0),
"observation.state": torch.randn( OBS_STATE: torch.randn(
10, 10,
), ),
"done": torch.tensor(False), "done": torch.tensor(False),
@@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB
def create_dummy_state() -> dict: def create_dummy_state() -> dict:
return { return {
"observation.image": create_random_image(), OBS_IMAGE: create_random_image(),
"observation.state": torch.randn( OBS_STATE: torch.randn(
10, 10,
), ),
} }
@@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer):
def test_zero_capacity_buffer_raises_error(): def test_zero_capacity_buffer_raises_error():
with pytest.raises(ValueError, match="Capacity must be greater than 0."): with pytest.raises(ValueError, match="Capacity must be greater than 0."):
ReplayBuffer(0, "cpu", ["observation", "next_observation"]) ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"])
def test_add_transition(replay_buffer, dummy_state, dummy_action): def test_add_transition(replay_buffer, dummy_state, dummy_action):
@@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
def test_add_over_capacity(): def test_add_over_capacity():
replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
dummy_state_1 = create_dummy_state() dummy_state_1 = create_dummy_state()
dummy_action_1 = create_dummy_action() dummy_action_1 = create_dummy_action()
@@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path):
assert ds.num_frames == 4 assert ds.num_frames == 4
for j, value in enumerate(ds): for j, value in enumerate(ds):
print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j]))
for i in range(len(ds)): for i in range(len(ds)):
for feature, value in ds[i].items(): for feature, value in ds[i].items():
@@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path):
assert torch.equal(value, buffer.rewards[i]) assert torch.equal(value, buffer.rewards[i])
elif feature == "next.done": elif feature == "next.done":
assert torch.equal(value, buffer.dones[i]) assert torch.equal(value, buffer.dones[i])
elif feature == "observation.image": elif feature == OBS_IMAGE:
# Tensor -> numpy is not precise, so we have some diff there # Tensor -> numpy is not precise, so we have some diff there
# TODO: Check and fix it # TODO: Check and fix it
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003)
elif feature == "observation.state": elif feature == OBS_STATE:
assert torch.equal(value, buffer.states["observation.state"][i]) assert torch.equal(value, buffer.states[OBS_STATE][i])
def test_from_lerobot_dataset(tmp_path): def test_from_lerobot_dataset(tmp_path):
@@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path):
) )
assert torch.equal( assert torch.equal(
replay_buffer.states["observation.state"][: len(replay_buffer)], replay_buffer.states[OBS_STATE][: len(replay_buffer)],
reconverted_buffer.states["observation.state"][: len(replay_buffer)], reconverted_buffer.states[OBS_STATE][: len(replay_buffer)],
), "State should be the same after converting to dataset and return back" ), "State should be the same after converting to dataset and return back"
for i in range(4): for i in range(4):
torch.testing.assert_close( torch.testing.assert_close(
replay_buffer.states["observation.image"][i], replay_buffer.states[OBS_IMAGE][i],
reconverted_buffer.states["observation.image"][i], reconverted_buffer.states[OBS_IMAGE][i],
rtol=0.4, rtol=0.4,
atol=0.004, atol=0.004,
) )
@@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path):
next_index = (i + 1) % 4 next_index = (i + 1) % 4
torch.testing.assert_close( torch.testing.assert_close(
replay_buffer.states["observation.image"][next_index], replay_buffer.states[OBS_IMAGE][next_index],
reconverted_buffer.next_states["observation.image"][i], reconverted_buffer.next_states[OBS_IMAGE][i],
rtol=0.4, rtol=0.4,
atol=0.004, atol=0.004,
) )
for i in range(2, 4): for i in range(2, 4):
assert torch.equal( assert torch.equal(
replay_buffer.states["observation.state"][i], replay_buffer.states[OBS_STATE][i],
reconverted_buffer.next_states["observation.state"][i], reconverted_buffer.next_states[OBS_STATE][i],
) )
@@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio
replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
sampled_transitions = replay_buffer.sample(1) sampled_transitions = replay_buffer.sample(1)
assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied"
"Image augmentations should be applied" assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), (
)
assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
"Image augmentations should be applied" "Image augmentations should be applied"
) )
@@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct
# Let's check that it doesn't fail and shapes are correct # Let's check that it doesn't fail and shapes are correct
sampled_transitions = replay_buffer.sample(1) sampled_transitions = replay_buffer.sample(1)
assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84)
assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84)
def test_random_crop_vectorized_basic(): def test_random_crop_vectorized_basic():
@@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
buffer = ReplayBuffer( buffer = ReplayBuffer(
capacity=capacity, capacity=capacity,
device="cpu", device="cpu",
state_keys=["observation.image", "observation.state"], state_keys=[OBS_IMAGE, OBS_STATE],
storage_device="cpu", storage_device="cpu",
) )
@@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
img = torch.ones(3, 128, 128) * i img = torch.ones(3, 128, 128) * i
state_vec = torch.arange(11).float() + i state_vec = torch.arange(11).float() + i
state = { state = {
"observation.image": img, OBS_IMAGE: img,
"observation.state": state_vec, OBS_STATE: state_vec,
} }
buffer.add( buffer.add(
state=state, state=state,
@@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic():
iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
batch = next(iterator) batch = next(iterator)
images = batch["state"]["observation.image"] images = batch["state"][OBS_IMAGE]
states = batch["state"]["observation.state"] states = batch["state"][OBS_STATE]
assert images.shape == (batch_size, 3, 128, 128) assert images.shape == (batch_size, 3, 128, 128)
assert states.shape == (batch_size, 11) assert states.shape == (batch_size, 11)
next_images = batch["next_state"]["observation.image"] next_images = batch["next_state"][OBS_IMAGE]
next_states = batch["next_state"]["observation.state"] next_states = batch["next_state"][OBS_STATE]
assert next_images.shape == (batch_size, 3, 128, 128) assert next_images.shape == (batch_size, 3, 128, 128)
assert next_states.shape == (batch_size, 11) assert next_states.shape == (batch_size, 11)
@@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations():
for _ in range(5): for _ in range(5):
batch = next(iterator) batch = next(iterator)
images = batch["state"]["observation.image"] images = batch["state"][OBS_IMAGE]
states = batch["state"]["observation.state"] states = batch["state"][OBS_STATE]
assert images.shape == (batch_size, 3, 128, 128) assert images.shape == (batch_size, 3, 128, 128)
assert states.shape == (batch_size, 11) assert states.shape == (batch_size, 11)
next_images = batch["next_state"]["observation.image"] next_images = batch["next_state"][OBS_IMAGE]
next_states = batch["next_state"]["observation.state"] next_states = batch["next_state"][OBS_STATE]
assert next_images.shape == (batch_size, 3, 128, 128) assert next_images.shape == (batch_size, 3, 128, 128)
assert next_states.shape == (batch_size, 11) assert next_states.shape == (batch_size, 11)

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest import pytest
from lerobot.processor import TransitionKey from lerobot.processor import TransitionKey
from lerobot.utils.constants import OBS_STATE
@pytest.fixture @pytest.fixture
@@ -72,7 +73,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
# Build EnvTransition dict # Build EnvTransition dict
obs = { obs = {
"observation.state.temperature": np.float32(25.0), f"{OBS_STATE}.temperature": np.float32(25.0),
# CHW image should be converted to HWC for rr.Image # CHW image should be converted to HWC for rr.Image
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
} }
@@ -97,7 +98,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
# - action.throttle -> Scalar # - action.throttle -> Scalar
# - action.vector_0, action.vector_1 -> Scalars # - action.vector_0, action.vector_1 -> Scalars
expected_keys = { expected_keys = {
"observation.state.temperature", f"{OBS_STATE}.temperature",
"observation.camera", "observation.camera",
"action.throttle", "action.throttle",
"action.vector_0", "action.vector_0",
@@ -106,7 +107,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
assert set(_keys(calls)) == expected_keys assert set(_keys(calls)) == expected_keys
# Check scalar types and values # Check scalar types and values
temp_obj = _obj_for(calls, "observation.state.temperature") temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar" assert type(temp_obj).__name__ == "DummyScalar"
assert temp_obj.value == pytest.approx(25.0) assert temp_obj.value == pytest.approx(25.0)