chore: replace hard-coded OBS values with constants throughout the source code (#2037)

* chore: replace hard-coded OBS values with constants throughout the source code
* chore(tests): replace hard-coded OBS values with constants throughout the test code
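The change swaps raw string literals like "observation.state" for the constants defined in lerobot/utils/constants.py. A minimal before/after sketch of the pattern (illustrative only; the tensors and keys below are made up):

import torch

from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

# Before: literals scattered across the codebase; a typo produces a silent
# missing-key bug at runtime.
batch = {"observation.state": torch.zeros(1, 6)}

# After: one definition site; a typo in the constant name fails loudly at
# import time instead.
batch = {
    OBS_STATE: torch.zeros(1, 6),
    f"{OBS_IMAGES}.front": torch.zeros(1, 3, 480, 640),
}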
@@ -41,6 +41,7 @@ from lerobot.datasets.video_utils import (
     decode_video_frames_torchvision,
     encode_video_frames,
 )
+from lerobot.utils.constants import OBS_IMAGE

 BASE_ENCODING = OrderedDict(
     [
@@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
     hf_dataset = dataset.hf_dataset.with_format(None)

     # We only save images from the first camera
-    img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
+    img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
     imgs_dataset = hf_dataset.select_columns(img_keys[0])

     for i, item in enumerate(

@@ -21,6 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors
 from lerobot.processor import make_default_processors
 from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
 from lerobot.scripts.lerobot_record import record_loop
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.control_utils import init_keyboard_listener
 from lerobot.utils.utils import log_say
 from lerobot.utils.visualization_utils import init_rerun
@@ -42,7 +43,7 @@ policy = ACTPolicy.from_pretrained(HF_MODEL_ID)

 # Configure the dataset features
 action_features = hw_to_dataset_features(robot.action_features, "action")
-obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
 dataset_features = {**action_features, **obs_features}

 # Create the dataset

@@ -22,6 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
 from lerobot.scripts.lerobot_record import record_loop
 from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
 from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.control_utils import init_keyboard_listener
 from lerobot.utils.utils import log_say
 from lerobot.utils.visualization_utils import init_rerun
@@ -48,7 +49,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m

 # Configure the dataset features
 action_features = hw_to_dataset_features(robot.action_features, "action")
-obs_features = hw_to_dataset_features(robot.observation_features, "observation")
+obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
 dataset_features = {**action_features, **obs_features}

 # Create the dataset

@@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
 # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
 from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig  # noqa: F401
 from lerobot.robots.robot import Robot
-from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
 from lerobot.utils.utils import init_logging

 Action = torch.Tensor
@@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy(


 def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
-    return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
+    return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)


 def is_image_key(k: str) -> bool:
@@ -141,7 +141,7 @@ def make_lerobot_observation(
     lerobot_features: dict[str, dict],
 ) -> LeRobotObservation:
     """Make a lerobot observation from a raw observation."""
-    return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
+    return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)


 def prepare_raw_observation(

@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
 )
 from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
+from lerobot.utils.constants import OBS_PREFIX

 IMAGENET_STATS = {
     "mean": [[[0.485]], [[0.456]], [[0.406]]],  # (c,1,1)
@@ -58,7 +59,7 @@ def resolve_delta_timestamps(
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
         if key == "action" and cfg.action_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
-        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
+        if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
             delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

     if len(delta_timestamps) == 0:

@@ -19,7 +19,7 @@ from typing import Any
 from lerobot.configs.types import PipelineFeatureType
 from lerobot.datasets.utils import hw_to_dataset_features
 from lerobot.processor import DataProcessorPipeline
-from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR


 def create_initial_features(
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(

     # Intermediate storage for categorized and filtered features.
     processed_features: dict[str, dict[str, Any]] = {
-        "action": {},
-        "observation": {},
+        ACTION: {},
+        OBS_STR: {},
     }
     images_token = OBS_IMAGES.split(".")[-1]

@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
         # 3. Add the feature to the appropriate group with a clean name.
         name = strip_prefix(key, PREFIXES_TO_STRIP)
         if is_action:
-            processed_features["action"][name] = value
+            processed_features[ACTION][name] = value
         else:
-            processed_features["observation"][name] = value
+            processed_features[OBS_STR][name] = value

     # Convert the processed features into the final dataset format.
     dataset_features = {}
-    if processed_features["action"]:
+    if processed_features[ACTION]:
         dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
-    if processed_features["observation"]:
-        dataset_features.update(
-            hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
-        )
+    if processed_features[OBS_STR]:
+        dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))

     return dataset_features

@@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import (
     BackwardCompatibilityError,
     ForwardCompatibilityError,
 )
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR
 from lerobot.utils.utils import is_valid_numpy_dtype_string

 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
@@ -652,7 +653,7 @@ def hw_to_dataset_features(
             "names": list(joint_fts),
         }

-    if joint_fts and prefix == "observation":
+    if joint_fts and prefix == OBS_STR:
         features[f"{prefix}.state"] = {
             "dtype": "float32",
             "shape": (len(joint_fts),),
@@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
             if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
-        elif key == "observation.environment_state":
+        elif key == OBS_ENV_STATE:
            type = FeatureType.ENV
-        elif key.startswith("observation"):
+        elif key.startswith(OBS_STR):
             type = FeatureType.STATE
         elif key.startswith("action"):
             type = FeatureType.ACTION
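For reference, a sketch of how the prefix argument to hw_to_dataset_features is used after this change; the hardware-feature dict below is hypothetical and only illustrates the call shape seen in the hunks above:

from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.utils.constants import OBS_STR

# Hypothetical robot.observation_features: joints map to a scalar type,
# cameras map to an (h, w, c) shape.
hw_features = {"shoulder_pan.pos": float, "front": (480, 640, 3)}

# Passing OBS_STR instead of the "observation" literal keeps the produced
# keys (e.g. "observation.state", "observation.images.front") in sync with
# the constants used everywhere else.
obs_features = hw_to_dataset_features(hw_features, OBS_STR, use_video=False)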
@@ -26,6 +26,7 @@ from torch import Tensor

 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.envs.configs import EnvConfig
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
 from lerobot.utils.utils import get_channel_first_image_shape


@@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     return_observations = {}
     if "pixels" in observations:
         if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
+            imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
         else:
-            imgs = {"observation.image": observations["pixels"]}
+            imgs = {OBS_IMAGE: observations["pixels"]}

         for imgkey, img in imgs.items():
             # TODO(aliberts, rcadene): use transforms.ToTensor()?
@@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         if env_state.dim() == 1:
             env_state = env_state.unsqueeze(0)

-        return_observations["observation.environment_state"] = env_state
+        return_observations[OBS_ENV_STATE] = env_state

     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     agent_pos = torch.from_numpy(observations["agent_pos"]).float()
     if agent_pos.dim() == 1:
         agent_pos = agent_pos.unsqueeze(0)
-    return_observations["observation.state"] = agent_pos
+    return_observations[OBS_STATE] = agent_pos

     return return_observations
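A short sketch of the key mapping preprocess_observation performs after this change (array shapes are illustrative):

import numpy as np

from lerobot.envs.utils import preprocess_observation
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

# Gym-style observation with two named cameras and a joint-position vector.
obs = {
    "pixels": {
        "top": np.zeros((480, 640, 3), dtype=np.uint8),
        "wrist": np.zeros((480, 640, 3), dtype=np.uint8),
    },
    "agent_pos": np.zeros(6, dtype=np.float32),
}

out = preprocess_observation(obs)

# Output keys are now built from the constants rather than literals.
assert f"{OBS_IMAGES}.top" in out  # "observation.images.top"
assert OBS_STATE in out            # "observation.state"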
@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d

 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.constants import ACTION, OBS_IMAGES
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE


 class ACTPolicy(PreTrainedPolicy):
@@ -398,10 +398,10 @@ class ACT(nn.Module):
                 "actions must be provided when using the variational objective in training mode."
             )

-        if "observation.images" in batch:
-            batch_size = batch["observation.images"][0].shape[0]
+        if OBS_IMAGES in batch:
+            batch_size = batch[OBS_IMAGES][0].shape[0]
         else:
-            batch_size = batch["observation.environment_state"].shape[0]
+            batch_size = batch[OBS_ENV_STATE].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch and self.training:
@@ -410,7 +410,7 @@ class ACT(nn.Module):
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
             if self.config.robot_state_feature:
-                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(batch["action"])  # (B, S, D)

@@ -430,7 +430,7 @@ class ACT(nn.Module):
             cls_joint_is_pad = torch.full(
                 (batch_size, 2 if self.config.robot_state_feature else 1),
                 False,
-                device=batch["observation.state"].device,
+                device=batch[OBS_STATE].device,
             )
             key_padding_mask = torch.cat(
                 [cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -454,7 +454,7 @@ class ACT(nn.Module):
             mu = log_sigma_x2 = None
             # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
             latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
-                batch["observation.state"].device
+                batch[OBS_STATE].device
             )

         # Prepare transformer encoder inputs.
@@ -462,18 +462,16 @@ class ACT(nn.Module):
         encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
         # Robot state token.
         if self.config.robot_state_feature:
-            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
         # Environment state token.
         if self.config.env_state_feature:
-            encoder_in_tokens.append(
-                self.encoder_env_state_input_proj(batch["observation.environment_state"])
-            )
+            encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))

         if self.config.image_features:
             # For a list of images, the H and W may vary but H*W is constant.
             # NOTE: If modifying this section, verify on MPS devices that
             # gradients remain stable (no explosions or NaNs).
-            for img in batch["observation.images"]:
+            for img in batch[OBS_IMAGES]:
                 cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
                 cam_features = self.encoder_img_feat_input_proj(cam_features)

@@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy):
     def reset(self):
         """Clear observation and action queues. Should be called on `env.reset()`"""
         self._queues = {
-            "observation.state": deque(maxlen=self.config.n_obs_steps),
+            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
             "action": deque(maxlen=self.config.n_action_steps),
         }
         if self.config.image_features:
-            self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
+            self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
         if self.config.env_state_feature:
-            self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
+            self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)

     @torch.no_grad()
     def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
@@ -234,7 +234,7 @@ class DiffusionModel(nn.Module):
         if self.config.image_features:
             if self.config.use_separate_rgb_encoder_per_camera:
                 # Combine batch and sequence dims while rearranging to make the camera index dimension first.
-                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
+                images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
                 img_features_list = torch.cat(
                     [
                         encoder(images)
@@ -249,7 +249,7 @@ class DiffusionModel(nn.Module):
             else:
                 # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
                 img_features = self.rgb_encoder(
-                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+                    einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
                 )
                 # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
                 # feature dim (effectively concatenating the camera features).
@@ -275,7 +275,7 @@ class DiffusionModel(nn.Module):
             "observation.environment_state": (B, n_obs_steps, environment_dim)
         }
         """
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

         # Encode image features and concatenate them all together along with the state vector.
@@ -306,9 +306,9 @@ class DiffusionModel(nn.Module):
         }
         """
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
-        assert "observation.images" in batch or "observation.environment_state" in batch
-        n_obs_steps = batch["observation.state"].shape[1]
+        assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
+        assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
+        n_obs_steps = batch[OBS_STATE].shape[1]
         horizon = batch["action"].shape[1]
         assert horizon == self.config.horizon
         assert n_obs_steps == self.config.n_obs_steps

@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("pi0")
@@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig):
         # raise ValueError("You must provide at least one image or the environment state among the inputs.")

         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),

@@ -21,6 +21,7 @@ import torch
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
 from lerobot.policies.factory import make_policy
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE


 def display(tensor: torch.Tensor):
@@ -60,26 +61,26 @@ def main():

     # Override stats
     dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
-    dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
+    dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
     )
-    dataset_meta.stats["observation.state"]["std"] = torch.tensor(
+    dataset_meta.stats[OBS_STATE]["std"] = torch.tensor(
         norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
     )

     # Create LeRobot batch from Jax
     batch = {}
     for cam_key, uint_chw_array in example["images"].items():
-        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
-    batch["observation.state"] = torch.from_numpy(example["state"])
+        batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
+    batch[OBS_STATE] = torch.from_numpy(example["state"])
     batch["action"] = torch.from_numpy(outputs["actions"])
     batch["task"] = example["prompt"]

     if model_name == "pi0_aloha_towel":
-        del batch["observation.images.cam_low"]
+        del batch[f"{OBS_IMAGES}.cam_low"]
     elif model_name == "pi0_aloha_sim":
-        batch["observation.images.top"] = batch["observation.images.cam_high"]
-        del batch["observation.images.cam_high"]
+        batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"]
+        del batch[f"{OBS_IMAGES}.cam_high"]

     # Batchify
     for key in batch:

@@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("pi0fast")
@@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),

@@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
 from lerobot.policies.utils import get_device_from_parameters
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE

 DISCRETE_DIMENSION_INDEX = -1  # Gripper is always the last dimension

@@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module):
         )

     def _init_state_layers(self) -> None:
-        self.has_env = "observation.environment_state" in self.config.input_features
-        self.has_state = "observation.state" in self.config.input_features
+        self.has_env = OBS_ENV_STATE in self.config.input_features
+        self.has_state = OBS_STATE in self.config.input_features
         if self.has_env:
-            dim = self.config.input_features["observation.environment_state"].shape[0]
+            dim = self.config.input_features[OBS_ENV_STATE].shape[0]
             self.env_encoder = nn.Sequential(
                 nn.Linear(dim, self.config.latent_dim),
                 nn.LayerNorm(self.config.latent_dim),
                 nn.Tanh(),
             )
         if self.has_state:
-            dim = self.config.input_features["observation.state"].shape[0]
+            dim = self.config.input_features[OBS_STATE].shape[0]
             self.state_encoder = nn.Sequential(
                 nn.Linear(dim, self.config.latent_dim),
                 nn.LayerNorm(self.config.latent_dim),
@@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module):
             cache = self.get_cached_image_features(obs)
             parts.append(self._encode_images(cache, detach))
         if self.has_env:
-            parts.append(self.env_encoder(obs["observation.environment_state"]))
+            parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
         if self.has_state:
-            parts.append(self.state_encoder(obs["observation.state"]))
+            parts.append(self.state_encoder(obs[OBS_STATE]))
         if parts:
             return torch.cat(parts, dim=-1)

@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import NormalizationMode
 from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
 from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.constants import OBS_IMAGE


 @PreTrainedConfig.register_subclass(name="reward_classifier")
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         """Validate feature configurations."""
-        has_image = any(key.startswith("observation.image") for key in self.input_features)
+        has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
         if not has_image:
             raise ValueError(
                 "You must provide an image observation (key starting with 'observation.image') in the input features"

@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import (
     CosineDecayWithWarmupSchedulerConfig,
 )
+from lerobot.utils.constants import OBS_IMAGES


 @PreTrainedConfig.register_subclass("smolvla")
@@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig):

     def validate_features(self) -> None:
         for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
+            key = f"{OBS_IMAGES}.empty_camera_{i}"
             empty_camera = PolicyFeature(
                 type=FeatureType.VISUAL,
                 shape=(3, 480, 640),

@@ -38,7 +38,7 @@ from torch import Tensor
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD


 class TDMPCPolicy(PreTrainedPolicy):
@@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy):
         called on `env.reset()`
         """
         self._queues = {
-            "observation.state": deque(maxlen=1),
+            OBS_STATE: deque(maxlen=1),
             "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
         }
         if self.config.image_features:
-            self._queues["observation.image"] = deque(maxlen=1)
+            self._queues[OBS_IMAGE] = deque(maxlen=1)
         if self.config.env_state_feature:
-            self._queues["observation.environment_state"] = deque(maxlen=1)
+            self._queues[OBS_ENV_STATE] = deque(maxlen=1)
         # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None
@@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy):

         action = batch[ACTION]  # (t, b, action_dim)
         reward = batch[REWARD]  # (t, b)
-        observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
+        observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}

         # Apply random image augmentations.
         if self.config.image_features and self.config.max_random_shift_ratio > 0:
@@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy):
                 temporal_loss_coeffs
                 * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
                 # `z_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # `z_targets` depends on the next observation.
-                * ~batch["observation.state_is_pad"][1:]
+                * ~batch[f"{OBS_STR}.state_is_pad"][1:]
             )
             .sum(0)
             .mean()
@@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy):
                 * F.mse_loss(reward_preds, reward, reduction="none")
                 * ~batch["next.reward_is_pad"]
                 # `reward_preds` depends on the current observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             )
             .sum(0)
@@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy):
                     reduction="none",
                 ).sum(0)  # sum over ensemble
                 # `q_preds_ensemble` depends on the first observation and the actions.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
                 # q_targets depends on the reward and the next observations.
                 * ~batch["next.reward_is_pad"]
-                * ~batch["observation.state_is_pad"][1:]
+                * ~batch[f"{OBS_STR}.state_is_pad"][1:]
             )
             .sum(0)
             .mean()
@@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy):
                 temporal_loss_coeffs
                 * raw_v_value_loss
                 # `v_targets` depends on the first observation and the actions, as does `v_preds`.
-                * ~batch["observation.state_is_pad"][0]
+                * ~batch[f"{OBS_STR}.state_is_pad"][0]
                 * ~batch["action_is_pad"]
             )
             .sum(0)
@@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy):
             * mse
             * temporal_loss_coeffs
             # `action_preds` depends on the first observation and the actions.
-            * ~batch["observation.state_is_pad"][0]
+            * ~batch[f"{OBS_STR}.state_is_pad"][0]
             * ~batch["action_is_pad"]
         ).mean()

@@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy):
             batch.pop(ACTION)
         batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
         # NOTE: It's important that this happens after stacking the images into a single key.
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
         if ACTION in batch:
             batch.pop(ACTION)
@@ -340,14 +340,12 @@ class VQBeTModel(nn.Module):

     def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
         # Input validation.
-        assert set(batch).issuperset({"observation.state", "observation.images"})
-        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
+        assert set(batch).issuperset({OBS_STATE, OBS_IMAGES})
+        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
         assert n_obs_steps == self.config.n_obs_steps

         # Extract image feature (first combine batch and sequence dims).
-        img_features = self.rgb_encoder(
-            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
-        )
+        img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(
             img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
@@ -359,9 +357,7 @@ class VQBeTModel(nn.Module):
             img_features
         )  # (batch, obs_step, number of different cameras, projection dims)
         input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
-        input_tokens.append(
-            self.state_projector(batch["observation.state"])
-        )  # (batch, obs_step, projection dims)
+        input_tokens.append(self.state_projector(batch[OBS_STATE]))  # (batch, obs_step, projection dims)
         input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
         # Interleave tokens by stacking and rearranging.
         input_tokens = torch.stack(input_tokens, dim=2)

@@ -23,6 +23,8 @@ from typing import Any
 import numpy as np
 import torch

+from lerobot.utils.constants import OBS_PREFIX
+
 from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey


@@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
         raise ValueError(f"Action should be a PolicyAction type got {type(action)}")

     # Extract observation and complementary data keys.
-    observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
+    observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)}
     complementary_data = _extract_complementary_data(batch)

     return create_transition(

@@ -21,7 +21,7 @@ import torch
 from torch import Tensor

 from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
+from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR

 from .pipeline import ObservationProcessorStep, ProcessorStepRegistry

@@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):

             # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1)
             for old_prefix, new_prefix in prefix_pairs.items():
-                prefixed_old = f"observation.{old_prefix}"
+                prefixed_old = f"{OBS_STR}.{old_prefix}"
                 if key.startswith(prefixed_old):
                     suffix = key[len(prefixed_old) :]
                     new_key = f"{new_prefix}{suffix}"
@@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):

             # Exact-name rules (pixels, environment_state, agent_pos)
             for old, new in exact_pairs.items():
-                if key == old or key == f"observation.{old}":
+                if key == old or key == f"{OBS_STR}.{old}":
                     new_key = new
                     new_features[src_ft][new_key] = feat
                     handled = True

@@ -24,6 +24,7 @@ import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import OBS_IMAGE
 from lerobot.utils.transition import Transition


@@ -240,7 +241,7 @@ class ReplayBuffer:
         idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device)

         # Identify image keys that need augmentation
-        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
+        image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else []

         # Create batched state and next_state
         batch_state = {}

@@ -73,6 +73,7 @@ from lerobot.teleoperators import (
 )
 from lerobot.teleoperators.teleoperator import Teleoperator
 from lerobot.teleoperators.utils import TeleopEvents
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
 from lerobot.utils.robot_utils import busy_wait
 from lerobot.utils.utils import log_say

@@ -180,7 +181,7 @@ class RobotEnv(gym.Env):

         # Define observation spaces for images and other states.
         if current_observation is not None and "pixels" in current_observation:
-            prefix = "observation.images"
+            prefix = OBS_IMAGES
             observation_spaces = {
                 f"{prefix}.{key}": gym.spaces.Box(
                     low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
@@ -190,7 +191,7 @@ class RobotEnv(gym.Env):

         if current_observation is not None:
             agent_pos = current_observation["agent_pos"]
-            observation_spaces["observation.state"] = gym.spaces.Box(
+            observation_spaces[OBS_STATE] = gym.spaces.Box(
                 low=0,
                 high=10,
                 shape=agent_pos.shape,
@@ -612,7 +613,7 @@ def control_loop(
     }

     for key, value in transition[TransitionKey.OBSERVATION].items():
        if key == "observation.state" is replaced as follows:
-        if key == "observation.state":
+        if key == OBS_STATE:
             features[key] = {
                 "dtype": "float32",
                 "shape": value.squeeze(0).shape,

@@ -23,6 +23,7 @@ from typing import Any
 import cv2
 import numpy as np

+from lerobot.utils.constants import OBS_STATE
 from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError

 from ..robot import Robot
@@ -203,7 +204,7 @@ class LeKiwiClient(Robot):

         state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)

-        obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec}
+        obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec}

         # Decode images
         current_frames: dict[str, np.ndarray] = {}

@@ -75,6 +75,7 @@ import torch.utils.data
 import tqdm

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.utils.constants import OBS_STATE


 class EpisodeSampler(torch.utils.data.Sampler):
@@ -161,8 +162,8 @@ def visualize_dataset(
                 rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

             # display each dimension of observed state space (e.g. agent position in joint space)
-            if "observation.state" in batch:
-                for dim_idx, val in enumerate(batch["observation.state"][i]):
+            if OBS_STATE in batch:
+                for dim_idx, val in enumerate(batch[OBS_STATE][i]):
                     rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

             if "next.done" in batch:

@@ -81,6 +81,7 @@ from lerobot.envs.utils import (
 from lerobot.policies.factory import make_policy, make_pre_post_processors
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.processor import PolicyAction, PolicyProcessorPipeline
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.io_utils import write_video
 from lerobot.utils.random_utils import set_seed
 from lerobot.utils.utils import (
@@ -221,7 +222,7 @@ def rollout(
         stacked_observations = {}
         for key in all_observations[0]:
             stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
-        ret["observation"] = stacked_observations
+        ret[OBS_STR] = stacked_observations

     if hasattr(policy, "use_original_modules"):
         policy.use_original_modules()
@@ -459,8 +460,8 @@ def _compile_episode_data(
         for k in ep_dict:
             ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])

-        for key in rollout_data["observation"]:
-            ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
+        for key in rollout_data[OBS_STR]:
+            ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames]

         ep_dicts.append(ep_dict)

@@ -109,6 +109,7 @@ from lerobot.teleoperators import (  # noqa: F401
     so101_leader,
 )
 from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.control_utils import (
     init_keyboard_listener,
     is_headless,
@@ -303,7 +304,7 @@ def record_loop(
         obs_processed = robot_observation_processor(obs)

         if policy is not None or dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
+            observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)

         # Get action from either policy or teleop
         if policy is not None and preprocessor is not None and postprocessor is not None:
@@ -17,19 +17,21 @@ from pathlib import Path

 from huggingface_hub.constants import HF_HOME

-OBS_ENV_STATE = "observation.environment_state"
-OBS_STATE = "observation.state"
-OBS_IMAGE = "observation.image"
-OBS_IMAGES = "observation.images"
-OBS_LANGUAGE = "observation.language"
+OBS_STR = "observation"
+OBS_PREFIX = OBS_STR + "."
+OBS_ENV_STATE = OBS_STR + ".environment_state"
+OBS_STATE = OBS_STR + ".state"
+OBS_IMAGE = OBS_STR + ".image"
+OBS_IMAGES = OBS_IMAGE + "s"
+OBS_LANGUAGE = OBS_STR + ".language"
+OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
+OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
+
 ACTION = "action"
 REWARD = "next.reward"
 TRUNCATED = "next.truncated"
 DONE = "next.done"

-OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
-OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
-
 ROBOTS = "robots"
 ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"
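Because each derived constant resolves to exactly the string it replaces, the refactor is behavior-preserving. A quick sanity check (a sketch, not part of the commit):

from lerobot.utils.constants import (
    OBS_ENV_STATE,
    OBS_IMAGE,
    OBS_IMAGES,
    OBS_PREFIX,
    OBS_STATE,
    OBS_STR,
)

# The derived constants resolve to the same literals they replaced.
assert OBS_STR == "observation"
assert OBS_PREFIX == "observation."
assert OBS_STATE == "observation.state"
assert OBS_IMAGE == "observation.image"
assert OBS_IMAGES == "observation.images"
assert OBS_ENV_STATE == "observation.environment_state"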
@@ -19,6 +19,8 @@ from typing import Any
 import numpy as np
 import rerun as rr

+from .constants import OBS_PREFIX, OBS_STR
+

 def init_rerun(session_name: str = "lerobot_control_loop") -> None:
     """Initializes the Rerun SDK for visualizing the control loop."""
@@ -63,7 +65,7 @@ def log_rerun_data(
     for k, v in observation.items():
         if v is None:
             continue
-        key = k if str(k).startswith("observation.") else f"observation.{k}"
+        key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"

         if _is_scalar(v):
             rr.log(key, rr.Scalar(float(v)))
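The hunk above also shows why two related constants exist: OBS_PREFIX ("observation.") for prefix tests and OBS_STR ("observation") for building namespaced keys. A small sketch of the same normalization rule in isolation:

from lerobot.utils.constants import OBS_PREFIX, OBS_STR

def normalize_obs_key(k: str) -> str:
    # Same rule as log_rerun_data: keep keys already in the observation
    # namespace, otherwise nest them under it.
    return k if k.startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"

assert normalize_obs_key("observation.state") == "observation.state"
assert normalize_obs_key("gripper.pos") == "observation.gripper.pos"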
@@ -24,6 +24,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
@@ -92,7 +93,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
# for backward compatibility
|
||||
if k == "task":
|
||||
continue
|
||||
if k.startswith("observation"):
|
||||
if k.startswith(OBS_STR):
|
||||
obs[k] = batch[k]
|
||||
|
||||
if hasattr(train_cfg.policy, "n_action_steps"):
|
||||
|
||||
@@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import (
|
||||
resize_robot_observation_image,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# FPSTracker
|
||||
@@ -115,7 +116,7 @@ def test_timed_action_getters():
|
||||
def test_timed_observation_getters():
|
||||
"""TimedObservation stores & returns timestamp, dict and timestep."""
|
||||
ts = time.time()
|
||||
obs_dict = {"observation.state": torch.ones(6)}
|
||||
obs_dict = {OBS_STATE: torch.ones(6)}
|
||||
to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0)
|
||||
|
||||
assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6)
|
||||
@@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters():
|
||||
# ------------------------------------------------------------------
|
||||
# TimedObservation
|
||||
# ------------------------------------------------------------------
|
||||
obs_dict = {"observation.state": torch.arange(4).float()}
|
||||
obs_dict = {OBS_STATE: torch.arange(4).float()}
|
||||
to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True)
|
||||
|
||||
to_bytes = pickle.dumps(to_in) # nosec
|
||||
@@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters():
|
||||
assert to_out.get_timestep() == 7
|
||||
assert to_out.must_go is True
|
||||
assert to_out.get_observation().keys() == obs_dict.keys()
|
||||
torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"])
|
||||
torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
@@ -187,7 +188,7 @@ def test_observations_similar_true():
|
||||
"""Distance below atol → observations considered similar."""
|
||||
# Create mock lerobot features for the similarity check
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
@@ -222,17 +223,17 @@ def _create_mock_robot_observation():
|
||||
def _create_mock_lerobot_features():
|
||||
"""Create mock lerobot features mapping similar to what hw_to_dataset_features returns."""
|
||||
return {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.phone": {
|
||||
f"{OBS_IMAGES}.phone": {
|
||||
"dtype": "image",
|
||||
"shape": [480, 640, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
@@ -243,11 +244,11 @@ def _create_mock_lerobot_features():
|
||||
def _create_mock_policy_image_features():
|
||||
"""Create mock policy image features with different resolutions."""
|
||||
return {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224), # Policy expects smaller resolution
|
||||
),
|
||||
"observation.images.phone": PolicyFeature(
|
||||
f"{OBS_IMAGES}.phone": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 160, 160), # Different resolution for second camera
|
||||
),
|
||||
@@ -306,21 +307,21 @@ def test_prepare_raw_observation():
|
||||
prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features)
|
||||
|
||||
# Check that state is properly extracted and batched
|
||||
assert "observation.state" in prepared
|
||||
state = prepared["observation.state"]
|
||||
assert OBS_STATE in prepared
|
||||
state = prepared[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.shape == (1, 4) # Batched state
|
||||
|
||||
# Check that images are processed and resized
|
||||
assert "observation.images.laptop" in prepared
|
||||
assert "observation.images.phone" in prepared
|
||||
assert f"{OBS_IMAGES}.laptop" in prepared
|
||||
assert f"{OBS_IMAGES}.phone" in prepared
|
||||
|
||||
laptop_img = prepared["observation.images.laptop"]
|
||||
phone_img = prepared["observation.images.phone"]
|
||||
laptop_img = prepared[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = prepared[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Check image shapes match policy requirements
|
||||
assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features["observation.images.phone"].shape
|
||||
assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape
|
||||
assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape
|
||||
|
||||
# Check that images are tensors
|
||||
assert isinstance(laptop_img, torch.Tensor)
|
||||
@@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic():
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device)
|
||||
|
||||
# Check that all expected keys are present
|
||||
assert "observation.state" in observation
|
||||
assert "observation.images.laptop" in observation
|
||||
assert "observation.images.phone" in observation
|
||||
assert OBS_STATE in observation
|
||||
assert f"{OBS_IMAGES}.laptop" in observation
|
||||
assert f"{OBS_IMAGES}.phone" in observation
|
||||
|
||||
# Check state processing
|
||||
state = observation["observation.state"]
|
||||
state = observation[OBS_STATE]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.device.type == device
|
||||
assert state.shape == (1, 4) # Batched
|
||||
|
||||
# Check image processing
|
||||
laptop_img = observation["observation.images.laptop"]
|
||||
phone_img = observation["observation.images.phone"]
|
||||
laptop_img = observation[f"{OBS_IMAGES}.laptop"]
|
||||
phone_img = observation[f"{OBS_IMAGES}.phone"]
|
||||
|
||||
# Images should have batch dimension: (B, C, H, W)
|
||||
assert laptop_img.shape == (1, 3, 224, 224)
|
||||
@@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content():
|
||||
|
||||
robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img}
|
||||
lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [4],
|
||||
"names": ["shoulder", "elbow", "wrist", "gripper"],
|
||||
},
|
||||
"observation.images.laptop": {
|
||||
f"{OBS_IMAGES}.laptop": {
|
||||
"dtype": "image",
|
||||
"shape": [100, 100, 3],
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
}
|
||||
policy_image_features = {
|
||||
"observation.images.laptop": PolicyFeature(
|
||||
f"{OBS_IMAGES}.laptop": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 50, 50), # Downsamples from 100x100
|
||||
)
|
||||
@@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content():
|
||||
|
||||
observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu")
|
||||
|
||||
processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim
|
||||
processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim
|
||||
|
||||
# Check that the center region has higher values than corners
|
||||
# Due to bilinear interpolation, exact values will change but pattern should remain
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -44,7 +45,7 @@ class MockPolicy:
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation["observation.state"])
|
||||
batch_size = len(observation[OBS_STATE])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
@@ -77,7 +78,7 @@ def policy_server():
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
|
||||
@@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import (
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
@@ -136,21 +137,21 @@ def test_get_feature_stats_single_value():
|
||||
|
||||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
"observation.image": [f"image_{i}.jpg" for i in range(100)],
|
||||
"observation.state": np.random.rand(100, 10),
|
||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||
OBS_STATE: np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
"observation.image": {"dtype": "image"},
|
||||
"observation.state": {"dtype": "numeric"},
|
||||
OBS_IMAGE: {"dtype": "image"},
|
||||
OBS_STATE: {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert "observation.image" in stats and "observation.state" in stats
|
||||
assert stats["observation.image"]["count"].item() == 100
|
||||
assert stats["observation.state"]["count"].item() == 100
|
||||
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
|
||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
@@ -224,38 +225,38 @@ def test_aggregate_feature_stats():
|
||||
def test_aggregate_stats():
|
||||
all_stats = [
|
||||
{
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [1, 2, 3],
|
||||
"max": [10, 20, 30],
|
||||
"mean": [5.5, 10.5, 15.5],
|
||||
"std": [2.87, 5.87, 8.87],
|
||||
"count": 10,
|
||||
},
|
||||
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
||||
OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
|
||||
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
|
||||
},
|
||||
{
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [2, 1, 0],
|
||||
"max": [15, 10, 5],
|
||||
"mean": [8.5, 5.5, 2.5],
|
||||
"std": [3.42, 2.42, 1.42],
|
||||
"count": 15,
|
||||
},
|
||||
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
||||
OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
|
||||
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
|
||||
},
|
||||
]
|
||||
|
||||
expected_agg_stats = {
|
||||
"observation.image": {
|
||||
OBS_IMAGE: {
|
||||
"min": [1, 1, 0],
|
||||
"max": [15, 20, 30],
|
||||
"mean": [7.3, 7.5, 7.7],
|
||||
"std": [3.5317, 4.8267, 8.5581],
|
||||
"count": 25,
|
||||
},
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"min": 1,
|
||||
"max": 15,
|
||||
"mean": 7.3,
|
||||
@@ -283,7 +284,7 @@ def test_aggregate_stats():
|
||||
for fkey, stats in ep_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
if fkey == OBS_IMAGE and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
@@ -292,7 +293,7 @@ def test_aggregate_stats():
|
||||
for fkey, stats in expected_agg_stats.items():
|
||||
for k in stats:
|
||||
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
|
||||
if fkey == "observation.image" and k != "count":
|
||||
if fkey == OBS_IMAGE and k != "count":
|
||||
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
|
||||
else:
|
||||
stats[k] = stats[k].reshape(1)
|
||||
|
||||
@@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup():
|
||||
def test_non_vector_last_wins_for_images():
|
||||
# Non-vector (images) with same name should be overwritten by the last image specified
|
||||
g1 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 480, 640),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"observation.images.front": {
|
||||
f"{OBS_IMAGES}.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 720, 1280),
|
||||
"names": ["channels", "height", "width"],
|
||||
@@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images():
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280)
|
||||
assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
|
||||
@@ -46,6 +46,7 @@ from lerobot.datasets.utils import (
 from lerobot.envs.factory import make_env_config
 from lerobot.policies.factory import make_policy_config
 from lerobot.robots import make_robot_from_config
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.mocks.mock_robot import MockRobotConfig
 from tests.utils import require_x86_64_kernel
@@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     # Instantiate both ways
     robot = make_robot_from_config(MockRobotConfig())
     action_features = hw_to_dataset_features(robot.action_features, "action", True)
-    obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
+    obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True)
     dataset_features = {**action_features, **obs_features}
     root_create = tmp_path / "create"
     dataset_create = LeRobotDataset.create(
@@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name):
         ("frame_index", 0, True),
         ("timestamp", 0, True),
         # TODO(rcadene): should we rename it agent_pos?
-        ("observation.state", 1, True),
+        (OBS_STATE, 1, True),
         ("next.reward", 0, False),
         ("next.done", 0, False),
     ]
@@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
 def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
     """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata."""
     features = {
-        "observation.state": {
+        OBS_STATE: {
             "dtype": "float32",
             "shape": (6,),
             "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
@@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
 def test_update_chunk_settings_video_dataset(tmp_path):
     """Test update_chunk_settings with a video dataset to ensure video-specific logic works."""
     features = {
-        "observation.images.cam": {
+        f"{OBS_IMAGES}.cam": {
             "dtype": "video",
             "shape": (480, 640, 3),
             "names": ["height", "width", "channels"],
@@ -19,6 +19,7 @@ import torch
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
 from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
+from lerobot.utils.constants import OBS_IMAGE
 from tests.utils import require_package


@@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params():

     config = RewardClassifierConfig()
     config.input_features = {
-        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
     config.output_features = {
         "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
@@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params():
     batch_size = 10

     input = {
-        "observation.image": torch.rand((batch_size, 3, 128, 128)),
+        OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
     }

@@ -83,7 +84,7 @@ def test_multiclass_classifier():
     num_classes = 5
     config = RewardClassifierConfig()
     config.input_features = {
-        "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
+        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
     }
     config.output_features = {
         "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
@@ -95,7 +96,7 @@ def test_multiclass_classifier():
     batch_size = 10

     input = {
-        "observation.image": torch.rand((batch_size, 3, 128, 128)),
+        OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
         "next.reward": torch.rand((batch_size, num_classes)),
     }
@@ -41,7 +41,7 @@ from lerobot.policies.factory import (
     make_pre_post_processors,
 )
 from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.constants import ACTION, OBS_STATE
+from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
 from lerobot.utils.random_utils import seeded_context
 from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
     # Create only one camera input which is squared to fit all current policy constraints
     # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
     camera_features = {
-        "observation.images.laptop": {
+        f"{OBS_IMAGES}.laptop": {
             "shape": (84, 84, 3),
             "names": ["height", "width", "channels"],
             "info": None,
@@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
             "shape": (6,),
             "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
         },
-        "observation.state": {
+        OBS_STATE: {
             "dtype": "float32",
             "shape": (6,),
             "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
@@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool):
     preventing erroneous creation of the policy object.
     """
     input_features = {
-        "observation.state": PolicyFeature(
+        OBS_STATE: PolicyFeature(
             type=FeatureType.STATE,
             shape=(10,),
         ),
@@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool):
     """Simulates the complete state/action is constructed from more granular multiple
     keys, of the same type as the overall state/action"""
     input_features = {}
-    input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
-    input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
-    input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
+    input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
+    input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
+    input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))

     output_features = {}
     output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
@@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
     PolicyConfig,
     SACConfig,
 )
+from lerobot.utils.constants import OBS_IMAGE, OBS_STATE


 def test_sac_config_default_initialization():
@@ -37,11 +38,11 @@ def test_sac_config_default_initialization():
         "ACTION": NormalizationMode.MIN_MAX,
     }
     assert config.dataset_stats == {
-        "observation.image": {
+        OBS_IMAGE: {
             "mean": [0.485, 0.456, 0.406],
             "std": [0.229, 0.224, 0.225],
         },
-        "observation.state": {
+        OBS_STATE: {
             "min": [0.0, 0.0],
             "max": [1.0, 1.0],
         },
@@ -90,11 +91,11 @@ def test_sac_config_default_initialization():

     # Dataset stats defaults
     expected_dataset_stats = {
-        "observation.image": {
+        OBS_IMAGE: {
             "mean": [0.485, 0.456, 0.406],
             "std": [0.229, 0.224, 0.225],
         },
-        "observation.state": {
+        OBS_STATE: {
             "min": [0.0, 0.0],
             "max": [1.0, 1.0],
         },
@@ -191,7 +192,7 @@ def test_sac_config_custom_initialization():

 def test_validate_features():
     config = SACConfig(
-        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
+        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
         output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
     )
     config.validate_features()
@@ -210,7 +211,7 @@ def test_validate_features_missing_observation():

 def test_validate_features_missing_action():
     config = SACConfig(
-        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
+        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
         output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
     )
     with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
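
The two validate_features tests above read more clearly with the constants in place. A minimal usage sketch of the same API, assuming a lerobot install (names and expected behavior taken directly from the tests above):

import pytest

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import OBS_STATE

# OBS_STATE resolves to "observation.state", so this builds the same config
# the old hard-coded test did.
config = SACConfig(
    input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
    output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()  # passes

# Dropping the "action" output feature makes validation fail, as tested above.
bad = SACConfig(
    input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
    output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
    bad.validate_features()
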
@@ -23,6 +23,7 @@ from torch import Tensor, nn
 from lerobot.configs.types import FeatureType, PolicyFeature
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
+from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
 from lerobot.utils.random_utils import seeded_context, set_seed

 try:
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():

 def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
     return {
-        "observation.state": torch.randn(batch_size, state_dim),
+        OBS_STATE: torch.randn(batch_size, state_dim),
     }


 def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
     return {
-        "observation.image": torch.randn(batch_size, 3, 84, 84),
-        "observation.state": torch.randn(batch_size, state_dim),
+        OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
+        OBS_STATE: torch.randn(batch_size, state_dim),
     }


@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(

 def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
     return {
-        "observation.state": torch.randn(batch_size, state_dim),
+        OBS_STATE: torch.randn(batch_size, state_dim),
     }


 def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
     return {
-        "observation.state": torch.randn(batch_size, state_dim),
-        "observation.image": torch.randn(batch_size, 3, 84, 84),
+        OBS_STATE: torch.randn(batch_size, state_dim),
+        OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
     }


@@ -180,10 +181,10 @@ def create_default_config(
         action_dim += 1

     config = SACConfig(
-        input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
+        input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
         output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
         dataset_stats={
-            "observation.state": {
+            OBS_STATE: {
                 "min": [0.0] * state_dim,
                 "max": [1.0] * state_dim,
             },
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
         continuous_action_dim=continuous_action_dim,
         has_discrete_action=has_discrete_action,
     )
-    config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
-    config.dataset_stats["observation.image"] = {
+    config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
+    config.dataset_stats[OBS_IMAGE] = {
         "mean": torch.randn(3, 1, 1),
         "std": torch.randn(3, 1, 1),
     }
@@ -342,7 +342,7 @@ def test_act_processor_batch_consistency():
     batch = transition_to_batch(transition)

     processed = preprocessor(batch)
-    assert processed["observation.state"].shape[0] == 1  # Batched
+    assert processed[OBS_STATE].shape[0] == 1  # Batched

     # Test already batched data
     observation_batched = {OBS_STATE: torch.randn(8, 7)}  # Batch of 8
@@ -2,14 +2,15 @@ import torch

 from lerobot.processor import DataProcessorPipeline, TransitionKey
 from lerobot.processor.converters import batch_to_transition, transition_to_batch
+from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE


 def _dummy_batch():
     """Create a dummy batch using the new format with observation.* and next.* keys."""
     return {
-        "observation.image.left": torch.randn(1, 3, 128, 128),
-        "observation.image.right": torch.randn(1, 3, 128, 128),
-        "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
+        f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
+        f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128),
+        OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
         "action": torch.tensor([[0.5]]),
         "next.reward": 1.0,
         "next.done": False,
@@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip():
     batch_out = proc(batch_in)

     # Check that all observation.* keys are preserved
-    original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
-    reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
+    original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)}
+    reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)}

     assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())

     # Check tensor values
-    assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
-    assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
-    assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
+    assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"])
+    assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"])
+    assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE])

     # Check other fields
     assert torch.allclose(batch_out["action"], batch_in["action"])
@@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip():
 def test_batch_to_transition_observation_grouping():
     """Test that batch_to_transition correctly groups observation.* keys."""
     batch = {
-        "observation.image.top": torch.randn(1, 3, 128, 128),
-        "observation.image.left": torch.randn(1, 3, 128, 128),
-        "observation.state": [1, 2, 3, 4],
+        f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
+        f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
+        OBS_STATE: [1, 2, 3, 4],
         "action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
         "next.reward": 1.5,
         "next.done": True,
@@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping():

     # Check observation is a dict with all observation.* keys
     assert isinstance(transition[TransitionKey.OBSERVATION], dict)
-    assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
-    assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
-    assert "observation.state" in transition[TransitionKey.OBSERVATION]
+    assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION]
+    assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION]
+    assert OBS_STATE in transition[TransitionKey.OBSERVATION]

     # Check values are preserved
     assert torch.allclose(
-        transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
+        transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"]
     )
     assert torch.allclose(
-        transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
+        transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"]
     )
-    assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
+    assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4]

     # Check other fields
     assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
@@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping():
 def test_transition_to_batch_observation_flattening():
     """Test that transition_to_batch correctly flattens observation dict."""
     observation_dict = {
-        "observation.image.top": torch.randn(1, 3, 128, 128),
-        "observation.image.left": torch.randn(1, 3, 128, 128),
-        "observation.state": [1, 2, 3, 4],
+        f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128),
+        f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128),
+        OBS_STATE: [1, 2, 3, 4],
     }

     transition = {
@@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening():
     batch = transition_to_batch(transition)

     # Check that observation.* keys are flattened back to batch
-    assert "observation.image.top" in batch
-    assert "observation.image.left" in batch
-    assert "observation.state" in batch
+    assert f"{OBS_IMAGE}.top" in batch
+    assert f"{OBS_IMAGE}.left" in batch
+    assert OBS_STATE in batch

     # Check values are preserved
-    assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
-    assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
-    assert batch["observation.state"] == [1, 2, 3, 4]
+    assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"])
+    assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"])
+    assert batch[OBS_STATE] == [1, 2, 3, 4]

     # Check other fields are mapped to next.* format
     assert batch["action"] == "action_data"
@@ -153,12 +154,12 @@ def test_no_observation_keys():

 def test_minimal_batch():
     """Test with minimal batch containing only observation.* and action."""
-    batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
+    batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])}

     transition = batch_to_transition(batch)

     # Check observation
-    assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
+    assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"}
     assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))

     # Check defaults
@@ -170,7 +171,7 @@ def test_minimal_batch():

     # Round trip
     reconstructed_batch = transition_to_batch(transition)
-    assert reconstructed_batch["observation.state"] == "minimal_state"
+    assert reconstructed_batch[OBS_STATE] == "minimal_state"
     assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
     assert reconstructed_batch["next.reward"] == 0.0
     assert not reconstructed_batch["next.done"]
@@ -205,9 +206,9 @@ def test_empty_batch():
 def test_complex_nested_observation():
     """Test with complex nested observation data."""
     batch = {
-        "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
-        "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
-        "observation.state": torch.randn(7),
+        f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
+        f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
+        OBS_STATE: torch.randn(7),
         "action": torch.randn(8),
         "next.reward": 3.14,
         "next.done": False,
@@ -219,20 +220,20 @@ def test_complex_nested_observation():
     reconstructed_batch = transition_to_batch(transition)

     # Check that all observation keys are preserved
-    original_obs_keys = {k for k in batch if k.startswith("observation.")}
-    reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
+    original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)}
+    reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)}

     assert original_obs_keys == reconstructed_obs_keys

     # Check tensor values
-    assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
+    assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE])

     # Check nested dict with tensors
     assert torch.allclose(
-        batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
+        batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"]
     )
     assert torch.allclose(
-        batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
+        batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"]
     )

     # Check action tensor
@@ -264,7 +265,7 @@ def test_custom_converter():
     processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)

     batch = {
-        "observation.state": torch.randn(1, 4),
+        OBS_STATE: torch.randn(1, 4),
         "action": torch.randn(1, 2),
         "next.reward": 1.0,
         "next.done": False,
@@ -274,5 +275,5 @@ def test_custom_converter():

     # Check the reward was doubled by our custom converter
     assert result["next.reward"] == 2.0
-    assert torch.allclose(result["observation.state"], batch["observation.state"])
+    assert torch.allclose(result[OBS_STATE], batch[OBS_STATE])
     assert torch.allclose(result["action"], batch["action"])
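
Worth noting for the hunks above: the startswith() filters switched to OBS_PREFIX ("observation.", with the trailing dot) rather than OBS_STR ("observation"), which preserves the old literal's behavior exactly. A small round-trip sketch of the grouping contract these tests exercise, assuming a lerobot install:

import torch

from lerobot.processor import TransitionKey
from lerobot.processor.converters import batch_to_transition, transition_to_batch
from lerobot.utils.constants import OBS_PREFIX, OBS_STATE

batch = {OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4)}

# batch_to_transition groups every key starting with OBS_PREFIX into the
# observation dict; other keys ("action", "next.*") map to their own slots.
transition = batch_to_transition(batch)
assert OBS_STATE in transition[TransitionKey.OBSERVATION]

# transition_to_batch flattens the observation dict back to top-level keys.
restored = transition_to_batch(transition)
assert {k for k in restored if k.startswith(OBS_PREFIX)} == {OBS_STATE}
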
@@ -9,6 +9,7 @@ from lerobot.processor.converters import (
     to_tensor,
     transition_to_batch,
 )
+from lerobot.utils.constants import OBS_STATE, OBS_STR


 # Tests for the unified to_tensor function
@@ -118,16 +119,16 @@ def test_to_tensor_dictionaries():
     # Nested dictionary
     nested = {
         "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
-        "observation": {"mean": np.array([0.5, 0.6]), "count": 10},
+        OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10},
     }
     result = to_tensor(nested)
     assert isinstance(result, dict)
     assert isinstance(result["action"], dict)
-    assert isinstance(result["observation"], dict)
+    assert isinstance(result[OBS_STR], dict)
     assert isinstance(result["action"]["mean"], torch.Tensor)
-    assert isinstance(result["observation"]["mean"], torch.Tensor)
+    assert isinstance(result[OBS_STR]["mean"], torch.Tensor)
     assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
-    assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))
+    assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6]))


 def test_to_tensor_none_filtering():
@@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields():

     # Create batch with index and task_index fields
     batch = {
-        "observation.state": torch.randn(1, 7),
+        OBS_STATE: torch.randn(1, 7),
         "action": torch.randn(1, 4),
         "next.reward": 1.5,
         "next.done": False,
@@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields():

     # Create transition with index and task_index in complementary_data
     transition = create_transition(
-        observation={"observation.state": torch.randn(1, 7)},
+        observation={OBS_STATE: torch.randn(1, 7)},
         action=torch.randn(1, 4),
         reward=1.5,
         done=False,
@@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields():

     # Batch without index/task_index
     batch = {
-        "observation.state": torch.randn(1, 7),
+        OBS_STATE: torch.randn(1, 7),
         "action": torch.randn(1, 4),
         "task": ["pick_cube"],
     }
@@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields():

     # Transition without index/task_index
     transition = create_transition(
-        observation={"observation.state": torch.randn(1, 7)},
+        observation={OBS_STATE: torch.randn(1, 7)},
         action=torch.randn(1, 4),
         complementary_data={"task": ["navigate"]},
     )
@@ -21,6 +21,7 @@ import torch
 from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
 from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
 from lerobot.processor.converters import create_transition, identity_transition
+from lerobot.utils.constants import OBS_IMAGE, OBS_STATE


 def test_basic_functionality():
@@ -28,7 +29,7 @@ def test_basic_functionality():
     processor = DeviceProcessorStep(device="cpu")

     # Create a transition with CPU tensors
-    observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
+    observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
     action = torch.randn(5)
     reward = torch.tensor(1.0)
     done = torch.tensor(False)
@@ -41,8 +42,8 @@ def test_basic_functionality():
     result = processor(transition)

     # Check that all tensors are on CPU
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
-    assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
+    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu"
     assert result[TransitionKey.ACTION].device.type == "cpu"
     assert result[TransitionKey.REWARD].device.type == "cpu"
     assert result[TransitionKey.DONE].device.type == "cpu"
@@ -55,7 +56,7 @@ def test_cuda_functionality():
     processor = DeviceProcessorStep(device="cuda")

     # Create a transition with CPU tensors
-    observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
+    observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)}
     action = torch.randn(5)
     reward = torch.tensor(1.0)
     done = torch.tensor(False)
@@ -68,8 +69,8 @@ def test_cuda_functionality():
     result = processor(transition)

     # Check that all tensors are on CUDA
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
-    assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
     assert result[TransitionKey.ACTION].device.type == "cuda"
     assert result[TransitionKey.REWARD].device.type == "cuda"
     assert result[TransitionKey.DONE].device.type == "cuda"
@@ -81,14 +82,14 @@ def test_specific_cuda_device():
     """Test device processor with specific CUDA device."""
     processor = DeviceProcessorStep(device="cuda:0")

-    observation = {"observation.state": torch.randn(10)}
+    observation = {OBS_STATE: torch.randn(10)}
     action = torch.randn(5)

     transition = create_transition(observation=observation, action=action)
     result = processor(transition)

-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0
     assert result[TransitionKey.ACTION].device.type == "cuda"
     assert result[TransitionKey.ACTION].device.index == 0

@@ -98,7 +99,7 @@ def test_non_tensor_values():
     processor = DeviceProcessorStep(device="cpu")

     observation = {
-        "observation.state": torch.randn(10),
+        OBS_STATE: torch.randn(10),
         "observation.metadata": {"key": "value"},  # Non-tensor data
         "observation.list": [1, 2, 3],  # Non-tensor data
     }
@@ -110,7 +111,7 @@ def test_non_tensor_values():
     result = processor(transition)

     # Check tensors are processed
-    assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor)
+    assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor)
     assert isinstance(result[TransitionKey.ACTION], torch.Tensor)

     # Check non-tensor values are preserved
@@ -130,9 +131,9 @@ def test_none_values():
     assert result[TransitionKey.ACTION].device.type == "cpu"

     # Test with None action
-    transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
+    transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
     result = processor(transition)
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
     assert result[TransitionKey.ACTION] is None


@@ -271,9 +272,7 @@ def test_features():
     processor = DeviceProcessorStep(device="cpu")

     features = {
-        PipelineFeatureType.OBSERVATION: {
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
-        },
+        PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
         PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
     }

@@ -376,7 +375,7 @@ def test_reward_done_truncated_types():

     # Test with scalar values (not tensors)
     transition = create_transition(
-        observation={"observation.state": torch.randn(5)},
+        observation={OBS_STATE: torch.randn(5)},
         action=torch.randn(3),
         reward=1.0,  # float
         done=False,  # bool
@@ -392,7 +391,7 @@ def test_reward_done_truncated_types():

     # Test with tensor values
     transition = create_transition(
-        observation={"observation.state": torch.randn(5)},
+        observation={OBS_STATE: torch.randn(5)},
         action=torch.randn(3),
         reward=torch.tensor(1.0),
         done=torch.tensor(False),
@@ -422,7 +421,7 @@ def test_complementary_data_preserved():
     }

     transition = create_transition(
-        observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data
+        observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data
     )

     result = processor(transition)
@@ -491,13 +490,13 @@ def test_float_dtype_bfloat16():
     """Test conversion to bfloat16."""
     processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16")

-    observation = {"observation.state": torch.randn(5, dtype=torch.float32)}
+    observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)}
     action = torch.randn(3, dtype=torch.float64)

     transition = create_transition(observation=observation, action=action)
     result = processor(transition)

-    assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
     assert result[TransitionKey.ACTION].dtype == torch.bfloat16


@@ -505,13 +504,13 @@ def test_float_dtype_float64():
     """Test conversion to float64."""
     processor = DeviceProcessorStep(device="cpu", float_dtype="float64")

-    observation = {"observation.state": torch.randn(5, dtype=torch.float16)}
+    observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)}
     action = torch.randn(3, dtype=torch.float32)

     transition = create_transition(observation=observation, action=action)
     result = processor(transition)

-    assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64
     assert result[TransitionKey.ACTION].dtype == torch.float64


@@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors():
     processor = DeviceProcessorStep(device="cpu", float_dtype="float32")

     observation = {
-        "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8),  # Should not convert
-        "observation.state": torch.randn(10, dtype=torch.float64),  # Should convert
+        OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8),  # Should not convert
+        OBS_STATE: torch.randn(10, dtype=torch.float64),  # Should convert
         "observation.mask": torch.tensor([True, False, True], dtype=torch.bool),  # Should not convert
         "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long),  # Should not convert
     }
@@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors():
     result = processor(transition)

     # Check conversions
-    assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8  # Unchanged
-    assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32  # Converted
+    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8  # Unchanged
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32  # Converted
     assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool  # Unchanged
     assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long  # Unchanged
     assert result[TransitionKey.ACTION].dtype == torch.float32  # Converted
@@ -612,7 +611,7 @@ def test_complementary_data_index_fields():
         "episode_id": 123,  # Non-tensor field
     }
     transition = create_transition(
-        observation={"observation.state": torch.randn(1, 7)},
+        observation={OBS_STATE: torch.randn(1, 7)},
         action=torch.randn(1, 4),
         complementary_data=complementary_data,
     )
@@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda():
     processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16")

     # Create full transition with mixed CPU tensors
-    observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)}
+    observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)}
     action = torch.randn(1, 4, dtype=torch.float32)
     reward = torch.tensor(1.5, dtype=torch.float32)
     done = torch.tensor(False)
@@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda():
     result = processor(transition)

     # Check all components moved to CUDA
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
     assert result[TransitionKey.ACTION].device.type == "cuda"
     assert result[TransitionKey.REWARD].device.type == "cuda"
     assert result[TransitionKey.DONE].device.type == "cuda"
@@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda():
     assert processed_comp_data["task_index"].device.type == "cuda"

     # Check float conversion happened for float tensors
-    assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
     assert result[TransitionKey.ACTION].dtype == torch.float16
     assert result[TransitionKey.REWARD].dtype == torch.float16

@@ -782,7 +781,7 @@ def test_complementary_data_empty():
     processor = DeviceProcessorStep(device="cpu")

     transition = create_transition(
-        observation={"observation.state": torch.randn(1, 7)},
+        observation={OBS_STATE: torch.randn(1, 7)},
         complementary_data={},
     )

@@ -797,7 +796,7 @@ def test_complementary_data_none():
     processor = DeviceProcessorStep(device="cpu")

     transition = create_transition(
-        observation={"observation.state": torch.randn(1, 7)},
+        observation={OBS_STATE: torch.randn(1, 7)},
         complementary_data=None,
     )

@@ -814,8 +813,8 @@ def test_preserves_gpu_placement():

     # Create tensors already on GPU
     observation = {
-        "observation.state": torch.randn(10).cuda(),  # Already on GPU
-        "observation.image": torch.randn(3, 224, 224).cuda(),  # Already on GPU
+        OBS_STATE: torch.randn(10).cuda(),  # Already on GPU
+        OBS_IMAGE: torch.randn(3, 224, 224).cuda(),  # Already on GPU
     }
     action = torch.randn(5).cuda()  # Already on GPU

@@ -823,14 +822,12 @@ def test_preserves_gpu_placement():
     result = processor(transition)

     # Check that tensors remain on their original GPU
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
-    assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda"
     assert result[TransitionKey.ACTION].device.type == "cuda"

     # Verify no unnecessary copies were made (same data pointer)
-    assert torch.equal(
-        result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"]
-    )
+    assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE])


 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
@@ -842,8 +839,8 @@ def test_multi_gpu_preservation():
     # Create tensors on cuda:1 (simulating Accelerate placement)
     cuda1_device = torch.device("cuda:1")
     observation = {
-        "observation.state": torch.randn(10).to(cuda1_device),
-        "observation.image": torch.randn(3, 224, 224).to(cuda1_device),
+        OBS_STATE: torch.randn(10).to(cuda1_device),
+        OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device),
     }
     action = torch.randn(5).to(cuda1_device)

@@ -851,20 +848,20 @@ def test_multi_gpu_preservation():
     result = processor_gpu(transition)

     # Check that tensors remain on cuda:1 (not moved to cuda:0)
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device
-    assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device
+    assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device
     assert result[TransitionKey.ACTION].device == cuda1_device

     # Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
     processor_cpu = DeviceProcessorStep(device="cpu")

     transition_gpu = create_transition(
-        observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
+        observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda()
     )
     result_cpu = processor_cpu(transition_gpu)

     # Check that tensors are moved to CPU
-    assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
+    assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
     assert result_cpu[TransitionKey.ACTION].device.type == "cpu"


@@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario():

     # Simulate data already placed by Accelerate
     device = torch.device(f"cuda:{gpu_id}")
-    observation = {"observation.state": torch.randn(1, 10).to(device)}
+    observation = {OBS_STATE: torch.randn(1, 10).to(device)}
     action = torch.randn(1, 5).to(device)

     transition = create_transition(observation=observation, action=action)
     result = processor(transition)

     # Verify data stays on the GPU where Accelerate placed it
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device == device
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device
     assert result[TransitionKey.ACTION].device == device


@@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data():
     }

     transition = create_transition(
-        observation={"observation.state": torch.randn(5, dtype=torch.float64)},
+        observation={OBS_STATE: torch.randn(5, dtype=torch.float64)},
         action=torch.randn(3, dtype=torch.float64),
         complementary_data=complementary_data,
     )
@@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data():
     result = processor(transition)

     # Check that all tensors are on MPS device
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps"
     assert result[TransitionKey.ACTION].device.type == "mps"

     processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
@@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data():
     assert processed_comp_data["float32_tensor"].device.type == "mps"

     # Check dtype conversions
-    assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32  # Converted
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32  # Converted
     assert result[TransitionKey.ACTION].dtype == torch.float32  # Converted
     assert processed_comp_data["float64_tensor"].dtype == torch.float32  # Converted
     assert processed_comp_data["float32_tensor"].dtype == torch.float32  # Unchanged
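
A compact sketch of the DeviceProcessorStep contract that the long run of assertions above keeps re-checking, written with the constant keys (CPU-only so it runs anywhere; assumes a lerobot install):

import torch

from lerobot.processor import DeviceProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE

step = DeviceProcessorStep(device="cpu", float_dtype="float32")
transition = create_transition(
    observation={
        OBS_STATE: torch.randn(10, dtype=torch.float64),  # float tensor: dtype converted
        OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8),  # non-float: untouched
    },
    action=torch.randn(5),
)
result = step(transition)
assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8
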
@@ -25,6 +25,7 @@ from pathlib import Path
 import pytest

 from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError
+from lerobot.utils.constants import OBS_STATE


 def test_is_processor_config_valid_configs():
@@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only():
     # Create a model config (like old LeRobot format)
     model_config = {
         "type": "act",
-        "input_features": {"observation.state": {"shape": [7]}},
+        "input_features": {OBS_STATE: {"shape": [7]}},
         "output_features": {"action": {"shape": [7]}},
         "hidden_dim": 256,
         "n_obs_steps": 1,

[File diff suppressed because it is too large]
@@ -39,8 +39,8 @@ def test_process_single_image():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that the image was processed correctly
-    assert "observation.image" in processed_obs
-    processed_img = processed_obs["observation.image"]
+    assert OBS_IMAGE in processed_obs
+    processed_img = processed_obs[OBS_IMAGE]

     # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
     assert processed_img.shape == (1, 3, 64, 64)
@@ -66,12 +66,12 @@ def test_process_image_dict():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that both images were processed
-    assert "observation.images.camera1" in processed_obs
-    assert "observation.images.camera2" in processed_obs
+    assert f"{OBS_IMAGES}.camera1" in processed_obs
+    assert f"{OBS_IMAGES}.camera2" in processed_obs

     # Check shapes
-    assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
-    assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
+    assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32)
+    assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48)


 def test_process_batched_image():
@@ -88,7 +88,7 @@ def test_process_batched_image():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that batch dimension is preserved
-    assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
+    assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64)


 def test_invalid_image_format():
@@ -173,10 +173,10 @@ def test_process_environment_state():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that environment_state was renamed and processed
-    assert "observation.environment_state" in processed_obs
+    assert OBS_ENV_STATE in processed_obs
     assert "environment_state" not in processed_obs

-    processed_state = processed_obs["observation.environment_state"]
+    processed_state = processed_obs[OBS_ENV_STATE]
     assert processed_state.shape == (1, 3)  # Batch dimension added
     assert processed_state.dtype == torch.float32
     torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
@@ -194,10 +194,10 @@ def test_process_agent_pos():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that agent_pos was renamed and processed
-    assert "observation.state" in processed_obs
+    assert OBS_STATE in processed_obs
     assert "agent_pos" not in processed_obs

-    processed_state = processed_obs["observation.state"]
+    processed_state = processed_obs[OBS_STATE]
     assert processed_state.shape == (1, 3)  # Batch dimension added
     assert processed_state.dtype == torch.float32
     torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
@@ -217,8 +217,8 @@ def test_process_batched_states():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that batch dimensions are preserved
-    assert processed_obs["observation.environment_state"].shape == (2, 2)
-    assert processed_obs["observation.state"].shape == (2, 2)
+    assert processed_obs[OBS_ENV_STATE].shape == (2, 2)
+    assert processed_obs[OBS_STATE].shape == (2, 2)


 def test_process_both_states():
@@ -235,8 +235,8 @@ def test_process_both_states():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that both states were processed
-    assert "observation.environment_state" in processed_obs
-    assert "observation.state" in processed_obs
+    assert OBS_ENV_STATE in processed_obs
+    assert OBS_STATE in processed_obs

     # Check that original keys were removed
     assert "environment_state" not in processed_obs
@@ -281,12 +281,12 @@ def test_complete_observation_processing():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check that image was processed
-    assert "observation.image" in processed_obs
-    assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
+    assert OBS_IMAGE in processed_obs
+    assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32)

     # Check that states were processed
-    assert "observation.environment_state" in processed_obs
-    assert "observation.state" in processed_obs
+    assert OBS_ENV_STATE in processed_obs
+    assert OBS_STATE in processed_obs

     # Check that original keys were removed
     assert "pixels" not in processed_obs
@@ -308,7 +308,7 @@ def test_image_only_processing():
     result = processor(transition)
     processed_obs = result[TransitionKey.OBSERVATION]

-    assert "observation.image" in processed_obs
+    assert OBS_IMAGE in processed_obs
     assert len(processed_obs) == 1


@@ -323,7 +323,7 @@ def test_state_only_processing():
     result = processor(transition)
     processed_obs = result[TransitionKey.OBSERVATION]

-    assert "observation.state" in processed_obs
+    assert OBS_STATE in processed_obs
     assert "agent_pos" not in processed_obs


@@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
     proc = VanillaObservationProcessorStep()
     features = {
         PipelineFeatureType.OBSERVATION: {
-            "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
+            OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)),
             "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
         },
     }
@@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory):
     assert (
         OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
         and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
-        == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
+        == features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
     )
     assert (
         OBS_STATE in out[PipelineFeatureType.OBSERVATION]
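
The vanilla observation processor maps raw environment keys onto the same constants: "agent_pos" becomes OBS_STATE, "pixels" becomes OBS_IMAGE, and "environment_state" becomes OBS_ENV_STATE, each with a batch dimension added. A sketch mirroring the tests above (the import location of VanillaObservationProcessorStep is not visible in this diff, so that line is an assumption):

import numpy as np

# Assumed import path; the tests above only show the class being used.
from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
from lerobot.processor.converters import create_transition
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE

processor = VanillaObservationProcessorStep()
transition = create_transition(
    observation={
        "pixels": np.zeros((64, 64, 3), dtype=np.uint8),  # -> OBS_IMAGE, shape (1, 3, 64, 64)
        "agent_pos": np.array([0.5, -0.5, 1.0], dtype=np.float32),  # -> OBS_STATE, shape (1, 3)
    }
)
obs = processor(transition)[TransitionKey.OBSERVATION]
assert OBS_IMAGE in obs and OBS_STATE in obs
assert "pixels" not in obs and "agent_pos" not in obs
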
@@ -35,6 +35,7 @@ from lerobot.processor import (
     TransitionKey,
 )
 from lerobot.processor.converters import create_transition, identity_transition
+from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
 from tests.conftest import assert_contract_is_typed


@@ -255,7 +256,7 @@ def test_step_through_with_dict():
     pipeline = DataProcessorPipeline([step1, step2])

     batch = {
-        "observation.image": None,
+        OBS_IMAGE: None,
         "action": None,
         "next.reward": 0.0,
         "next.done": False,
@@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions():

     # Verify it uses default converters by checking with standard batch format
     batch = {
-        "observation.image": torch.randn(1, 3, 32, 32),
+        OBS_IMAGE: torch.randn(1, 3, 32, 32),
         "action": torch.randn(1, 7),
         "next.reward": torch.tensor([1.0]),
         "next.done": torch.tensor([False]),
@@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions():
     # Should work with standard format (wouldn't work with custom converter)
     result = loaded(batch)
     # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
-    assert "observation.image" in result
+    assert OBS_IMAGE in result


 class NonCompliantStep:
@@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep):
         self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
     ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
         # State features (mix EE and a joint state)
-        features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float
-        features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float
+        features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float
+        features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float
         if self.add_front_image:
-            features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape
+            features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape
         return features


@@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only():
     )

     # Expect only "action" with joint names
-    assert "action" in out and "observation.state" not in out
+    assert "action" in out and OBS_STATE not in out
     assert out["action"]["dtype"] == "float32"
     assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"}
     assert out["action"]["shape"] == (len(out["action"]["names"]),)
@@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos():
         pipeline=rp,
         initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}},
         use_videos=True,
-        patterns=["action.ee", "observation.state"],
+        patterns=["action.ee", OBS_STATE],
     )

     # Action should pack only EE names
@@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos():
     assert out["action"]["dtype"] == "float32"

     # Observation state should pack both ee.x and j1.pos as a vector
-    assert "observation.state" in out
-    assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"}
-    assert out["observation.state"]["dtype"] == "float32"
+    assert OBS_STATE in out
+    assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"}
+    assert out[OBS_STATE]["dtype"] == "float32"

     # Cameras from initial_features appear as videos
     for cam in ("front", "side"):
-        key = f"observation.images.{cam}"
+        key = f"{OBS_IMAGES}.{cam}"
         assert key in out
         assert out[key]["dtype"] == "video"
         assert out[key]["shape"] == initial[cam]
@@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false():
         patterns=None,
     )

-    key = "observation.images.back"
-    key_front = "observation.images.front"
+    key = f"{OBS_IMAGES}.back"
+    key_front = f"{OBS_IMAGES}.front"
     assert key not in out
     assert key_front not in out

@@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true():
         patterns=None,
    )

-    key = "observation.images.front"
-    key_back = "observation.images.back"
+    key = f"{OBS_IMAGES}.front"
+    key_back = f"{OBS_IMAGES}.back"
     assert key in out
     assert key_back in out
     assert out[key]["dtype"] == "video"
@@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image():
         pipeline=rp,
         initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial},
         use_videos=True,
-        patterns=["observation.images.front"],
+        patterns=[f"{OBS_IMAGES}.front"],
     )

-    key = "observation.images.front"
+    key = f"{OBS_IMAGES}.front"
     assert key in out
     assert out[key]["shape"] == (240, 320, 3)  # from the step, not from initial
@@ -28,6 +28,7 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from tests.conftest import assert_contract_is_typed

@@ -121,13 +122,13 @@ def test_overlapping_rename():
 def test_partial_rename():
     """Test renaming only some keys."""
     rename_map = {
-        "observation.state": "observation.proprio_state",
-        "pixels": "observation.image",
+        OBS_STATE: "observation.proprio_state",
+        "pixels": OBS_IMAGE,
     }
     processor = RenameObservationsProcessorStep(rename_map=rename_map)

     observation = {
-        "observation.state": torch.randn(10),
+        OBS_STATE: torch.randn(10),
         "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
         "reward": 1.0,
         "info": {"episode": 1},
@@ -139,8 +140,8 @@ def test_partial_rename():

     # Check renamed keys
     assert "observation.proprio_state" in processed_obs
-    assert "observation.image" in processed_obs
-    assert "observation.state" not in processed_obs
+    assert OBS_IMAGE in processed_obs
+    assert OBS_STATE not in processed_obs
     assert "pixels" not in processed_obs

     # Check unchanged keys
@@ -174,8 +175,8 @@ def test_state_dict():
 def test_integration_with_robot_processor():
     """Test integration with RobotProcessor pipeline."""
     rename_map = {
-        "agent_pos": "observation.state",
-        "pixels": "observation.image",
+        "agent_pos": OBS_STATE,
+        "pixels": OBS_IMAGE,
     }
     rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)

@@ -196,8 +197,8 @@ def test_integration_with_robot_processor():
     processed_obs = result[TransitionKey.OBSERVATION]

     # Check renaming worked through pipeline
-    assert "observation.state" in processed_obs
-    assert "observation.image" in processed_obs
+    assert OBS_STATE in processed_obs
+    assert OBS_IMAGE in processed_obs
     assert "agent_pos" not in processed_obs
     assert "pixels" not in processed_obs
     assert processed_obs["other_data"] == "preserve_me"
@@ -210,8 +211,8 @@ def test_integration_with_robot_processor():
 def test_save_and_load_pretrained():
     """Test saving and loading processor with RobotProcessor."""
     rename_map = {
-        "old_state": "observation.state",
-        "old_image": "observation.image",
+        "old_state": OBS_STATE,
+        "old_image": OBS_IMAGE,
     }
     processor = RenameObservationsProcessorStep(rename_map=rename_map)
     pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
@@ -253,10 +254,10 @@ def test_save_and_load_pretrained():
     result = loaded_pipeline(transition)
     processed_obs = result[TransitionKey.OBSERVATION]

-    assert "observation.state" in processed_obs
-    assert "observation.image" in processed_obs
-    assert processed_obs["observation.state"] == [1, 2, 3]
-    assert processed_obs["observation.image"] == "image_data"
+    assert OBS_STATE in processed_obs
+    assert OBS_IMAGE in processed_obs
+    assert processed_obs[OBS_STATE] == [1, 2, 3]
+    assert processed_obs[OBS_IMAGE] == "image_data"


 def test_registry_functionality():
@@ -317,8 +318,8 @@ def test_chained_rename_processors():
     # Second processor: rename to final format
     processor2 = RenameObservationsProcessorStep(
         rename_map={
-            "agent_position": "observation.state",
-            "camera_image": "observation.image",
+            "agent_position": OBS_STATE,
+            "camera_image": OBS_IMAGE,
         }
     )

@@ -342,8 +343,8 @@ def test_chained_rename_processors():

     # After second processor
     final_obs = results[2][TransitionKey.OBSERVATION]
-    assert "observation.state" in final_obs
-    assert "observation.image" in final_obs
+    assert OBS_STATE in final_obs
+    assert OBS_IMAGE in final_obs
     assert final_obs["extra"] == "keep_me"

     # Original keys should be gone
@@ -356,15 +357,15 @@ def test_chained_rename_processors():
 def test_nested_observation_rename():
     """Test renaming with nested observation structures."""
     rename_map = {
-        "observation.images.left": "observation.camera.left_view",
-        "observation.images.right": "observation.camera.right_view",
+        f"{OBS_IMAGES}.left": "observation.camera.left_view",
+        f"{OBS_IMAGES}.right": "observation.camera.right_view",
         "observation.proprio": "observation.proprioception",
     }
     processor = RenameObservationsProcessorStep(rename_map=rename_map)

     observation = {
-        "observation.images.left": torch.randn(3, 64, 64),
-        "observation.images.right": torch.randn(3, 64, 64),
+        f"{OBS_IMAGES}.left": torch.randn(3, 64, 64),
+        f"{OBS_IMAGES}.right": torch.randn(3, 64, 64),
         "observation.proprio": torch.randn(7),
         "observation.gripper": torch.tensor([0.0]),  # Not renamed
     }
@@ -382,8 +383,8 @@ def test_nested_observation_rename():
     assert "observation.gripper" in processed_obs

     # Check old keys removed
-    assert "observation.images.left" not in processed_obs
-    assert "observation.images.right" not in processed_obs
+    assert f"{OBS_IMAGES}.left" not in processed_obs
+    assert f"{OBS_IMAGES}.right" not in processed_obs
     assert "observation.proprio" not in processed_obs


@@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory):
     # Chain two rename processors at the contract level
     processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
     processor2 = RenameObservationsProcessorStep(
-        rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
+        rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE}
     )
     pipeline = DataProcessorPipeline([processor1, processor2])

@@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory):
     }
     out = pipeline.transform_features(initial_features=spec)

-    assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
-    assert (
-        out[PipelineFeatureType.OBSERVATION]["observation.state"]
-        == spec[PipelineFeatureType.OBSERVATION]["pos"]
-    )
-    assert (
-        out[PipelineFeatureType.OBSERVATION]["observation.image"]
-        == spec[PipelineFeatureType.OBSERVATION]["img"]
-    )
+    assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"}
+    assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"]
+    assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"]
     assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
     assert_contract_is_typed(out)


 def test_rename_stats_basic():
     orig = {
-        "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
+        OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])},
         "action": {"mean": np.array([0.0])},
     }
-    mapping = {"observation.state": "observation.robot_state"}
+    mapping = {OBS_STATE: "observation.robot_state"}
     renamed = rename_stats(orig, mapping)
-    assert "observation.robot_state" in renamed and "observation.state" not in renamed
+    assert "observation.robot_state" in renamed and OBS_STATE not in renamed
     # Ensure deep copy: mutate original and verify renamed unaffected
-    orig["observation.state"]["mean"][0] = 42.0
+    orig[OBS_STATE]["mean"][0] = 42.0
     assert renamed["observation.robot_state"]["mean"][0] != 42.0

@@ -11,7 +11,7 @@ import torch
 from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
 from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
 from lerobot.processor.converters import create_transition, identity_transition
-from lerobot.utils.constants import OBS_LANGUAGE
+from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
 from tests.utils import require_package


@@ -503,16 +503,14 @@ def test_features_basic():
     processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)

     input_features = {
-        PipelineFeatureType.OBSERVATION: {
-            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
-        },
+        PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
         PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
     }

     output_features = processor.transform_features(input_features)

     # Check that original features are preserved
-    assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION]
+    assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION]
     assert "action" in output_features[PipelineFeatureType.ACTION]

     # Check that tokenized features are added
@@ -797,7 +795,7 @@ def test_device_detection_cpu():
     processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)

     # Create transition with CPU tensors
-    observation = {"observation.state": torch.randn(10)}  # CPU tensor
+    observation = {OBS_STATE: torch.randn(10)}  # CPU tensor
     action = torch.randn(5)  # CPU tensor
     transition = create_transition(
         observation=observation, action=action, complementary_data={"task": "test task"}
@@ -821,7 +819,7 @@ def test_device_detection_cuda():
     processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)

     # Create transition with CUDA tensors
-    observation = {"observation.state": torch.randn(10).cuda()}  # CUDA tensor
+    observation = {OBS_STATE: torch.randn(10).cuda()}  # CUDA tensor
     action = torch.randn(5).cuda()  # CUDA tensor
     transition = create_transition(
         observation=observation, action=action, complementary_data={"task": "test task"}
@@ -847,7 +845,7 @@ def test_device_detection_multi_gpu():

     # Test with tensors on cuda:1
     device = torch.device("cuda:1")
-    observation = {"observation.state": torch.randn(10).to(device)}
+    observation = {OBS_STATE: torch.randn(10).to(device)}
     action = torch.randn(5).to(device)
     transition = create_transition(
         observation=observation, action=action, complementary_data={"task": "multi gpu test"}
@@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype():
     processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)

     # Create transition with float tensor (to test dtype isn't affected)
-    observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
+    observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)}
     transition = create_transition(observation=observation, complementary_data={"task": "dtype test"})

     result = processor(transition)
@@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):

     # Start with CPU tensors
     transition = create_transition(
-        observation={"observation.state": torch.randn(10)},  # CPU
+        observation={OBS_STATE: torch.randn(10)},  # CPU
         action=torch.randn(5),  # CPU
         complementary_data={"task": "pipeline test"},
     )
@@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
     result = robot_processor(transition)

     # All tensors should end up on CUDA (moved by DeviceProcessorStep)
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
     assert result[TransitionKey.ACTION].device.type == "cuda"

     # Tokenized tensors should also be on CUDA
@@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario():
     # Simulate Accelerate scenario: batch already on GPU
     device = torch.device("cuda:0")
     observation = {
-        "observation.state": torch.randn(1, 10).to(device),  # Batched, on GPU
-        "observation.image": torch.randn(1, 3, 224, 224).to(device),  # Batched, on GPU
+        OBS_STATE: torch.randn(1, 10).to(device),  # Batched, on GPU
+        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),  # Batched, on GPU
     }
     action = torch.randn(1, 5).to(device)  # Batched, on GPU
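
All the device-detection hunks above share one pattern: TokenizerProcessorStep receives no explicit device and must infer it from tensors already present in the observation, so the token tensors it creates land on the same device (CPU, cuda:0, or cuda:1) without disturbing dtype. A minimal sketch of such an inference, assuming the step does something along these lines:

import torch


def detect_device_sketch(observation: dict) -> torch.device:
    # Adopt the device of the first tensor found; default to CPU otherwise.
    for value in observation.values():
        if isinstance(value, torch.Tensor):
            return value.device
    return torch.device("cpu")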

@@ -21,6 +21,7 @@ import pytest
 import torch
 from torch.multiprocessing import Event, Queue

+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.transition import Transition
 from tests.utils import require_package

@@ -110,12 +111,12 @@ def test_push_transitions_to_transport_queue():
     transitions = []
     for i in range(3):
         transition = Transition(
-            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
+            state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
             action=torch.randn(5),
             reward=torch.tensor(1.0 + i),
             done=torch.tensor(False),
             truncated=torch.tensor(False),
-            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
+            next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
             complementary_info={"step": torch.tensor(i)},
         )
         transitions.append(transition)

@@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue

 from lerobot.configs.train import TrainRLServerPipelineConfig
 from lerobot.policies.sac.configuration_sac import SACConfig
+from lerobot.utils.constants import OBS_STR
 from lerobot.utils.transition import Transition
 from tests.utils import require_package

@@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]:
     transitions = []
     for i in range(count):
         transition = Transition(
-            state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
+            state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
             action=torch.randn(5),
             reward=torch.tensor(1.0 + i),
             done=torch.tensor(i == count - 1),  # Last transition is done
             truncated=torch.tensor(False),
-            next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
+            next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
             complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
         )
         transitions.append(transition)

@@ -22,11 +22,12 @@ import torch

 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
+from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
 from tests.fixtures.constants import DUMMY_REPO_ID


 def state_dims() -> list[str]:
-    return ["observation.image", "observation.state"]
+    return [OBS_IMAGE, OBS_STATE]


 @pytest.fixture
@@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor:

 def create_dummy_transition() -> dict:
     return {
-        "observation.image": create_random_image(),
+        OBS_IMAGE: create_random_image(),
         "action": torch.randn(4),
         "reward": torch.tensor(1.0),
-        "observation.state": torch.randn(
+        OBS_STATE: torch.randn(
             10,
         ),
         "done": torch.tensor(False),
@@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB

 def create_dummy_state() -> dict:
     return {
-        "observation.image": create_random_image(),
-        "observation.state": torch.randn(
+        OBS_IMAGE: create_random_image(),
+        OBS_STATE: torch.randn(
             10,
         ),
     }
@@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer):

 def test_zero_capacity_buffer_raises_error():
     with pytest.raises(ValueError, match="Capacity must be greater than 0."):
-        ReplayBuffer(0, "cpu", ["observation", "next_observation"])
+        ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"])


 def test_add_transition(replay_buffer, dummy_state, dummy_action):
@@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):


 def test_add_over_capacity():
-    replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"])
+    replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
     dummy_state_1 = create_dummy_state()
     dummy_action_1 = create_dummy_action()
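
test_add_over_capacity exercises the ring behavior implied by its name: with capacity 2, a third add must evict the oldest transition instead of growing the buffer. A hedged usage sketch assembled only from calls visible in this diff (the positional add(state, action, reward, next_state, done, truncated) signature and len() support both appear in later hunks):

buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
for _ in range(3):
    state = create_dummy_state()
    buffer.add(state, create_dummy_action(), 1.0, state, False, False)
# At most `capacity` transitions are retained.
assert len(buffer) == 2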

@@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path):
     assert ds.num_frames == 4

     for j, value in enumerate(ds):
-        print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j]))
+        print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j]))

     for i in range(len(ds)):
         for feature, value in ds[i].items():
@@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path):
                 assert torch.equal(value, buffer.rewards[i])
             elif feature == "next.done":
                 assert torch.equal(value, buffer.dones[i])
-            elif feature == "observation.image":
+            elif feature == OBS_IMAGE:
                 # Tensor -> numpy is not precise, so we have some diff there
                 # TODO: Check and fix it
-                torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
-            elif feature == "observation.state":
-                assert torch.equal(value, buffer.states["observation.state"][i])
+                torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003)
+            elif feature == OBS_STATE:
+                assert torch.equal(value, buffer.states[OBS_STATE][i])


 def test_from_lerobot_dataset(tmp_path):
@@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path):
     )

     assert torch.equal(
-        replay_buffer.states["observation.state"][: len(replay_buffer)],
-        reconverted_buffer.states["observation.state"][: len(replay_buffer)],
+        replay_buffer.states[OBS_STATE][: len(replay_buffer)],
+        reconverted_buffer.states[OBS_STATE][: len(replay_buffer)],
     ), "State should be the same after converting to dataset and return back"

     for i in range(4):
         torch.testing.assert_close(
-            replay_buffer.states["observation.image"][i],
-            reconverted_buffer.states["observation.image"][i],
+            replay_buffer.states[OBS_IMAGE][i],
+            reconverted_buffer.states[OBS_IMAGE][i],
             rtol=0.4,
             atol=0.004,
         )
@@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path):
         next_index = (i + 1) % 4

         torch.testing.assert_close(
-            replay_buffer.states["observation.image"][next_index],
-            reconverted_buffer.next_states["observation.image"][i],
+            replay_buffer.states[OBS_IMAGE][next_index],
+            reconverted_buffer.next_states[OBS_IMAGE][i],
             rtol=0.4,
             atol=0.004,
         )

     for i in range(2, 4):
         assert torch.equal(
-            replay_buffer.states["observation.state"][i],
-            reconverted_buffer.next_states["observation.state"][i],
+            replay_buffer.states[OBS_STATE][i],
+            reconverted_buffer.next_states[OBS_STATE][i],
         )


@@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio
     replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

     sampled_transitions = replay_buffer.sample(1)
-    assert torch.all(sampled_transitions["state"]["observation.image"] == 10), (
-        "Image augmentations should be applied"
-    )
-    assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), (
+    assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied"
+    assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), (
         "Image augmentations should be applied"
     )

@@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct

     # Let's check that it doesn't fail and shapes are correct
     sampled_transitions = replay_buffer.sample(1)
-    assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84)
-    assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84)
+    assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84)
+    assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84)


 def test_random_crop_vectorized_basic():
@@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
     buffer = ReplayBuffer(
         capacity=capacity,
         device="cpu",
-        state_keys=["observation.image", "observation.state"],
+        state_keys=[OBS_IMAGE, OBS_STATE],
         storage_device="cpu",
     )

@@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
         img = torch.ones(3, 128, 128) * i
         state_vec = torch.arange(11).float() + i
         state = {
-            "observation.image": img,
-            "observation.state": state_vec,
+            OBS_IMAGE: img,
+            OBS_STATE: state_vec,
         }
         buffer.add(
             state=state,
@@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic():
     iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
     batch = next(iterator)

-    images = batch["state"]["observation.image"]
-    states = batch["state"]["observation.state"]
+    images = batch["state"][OBS_IMAGE]
+    states = batch["state"][OBS_STATE]

     assert images.shape == (batch_size, 3, 128, 128)
     assert states.shape == (batch_size, 11)

-    next_images = batch["next_state"]["observation.image"]
-    next_states = batch["next_state"]["observation.state"]
+    next_images = batch["next_state"][OBS_IMAGE]
+    next_states = batch["next_state"][OBS_STATE]

     assert next_images.shape == (batch_size, 3, 128, 128)
     assert next_states.shape == (batch_size, 11)
@@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations():

     for _ in range(5):
         batch = next(iterator)
-        images = batch["state"]["observation.image"]
-        states = batch["state"]["observation.state"]
+        images = batch["state"][OBS_IMAGE]
+        states = batch["state"][OBS_STATE]
         assert images.shape == (batch_size, 3, 128, 128)
         assert states.shape == (batch_size, 11)

-        next_images = batch["next_state"]["observation.image"]
-        next_states = batch["next_state"]["observation.state"]
+        next_images = batch["next_state"][OBS_IMAGE]
+        next_states = batch["next_state"][OBS_STATE]
         assert next_images.shape == (batch_size, 3, 128, 128)
         assert next_states.shape == (batch_size, 11)
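
The two async-iterator tests reduce to a single usage pattern: get_iterator(batch_size=..., async_prefetch=True, queue_size=...) yields batches shaped as {"state": {...}, "next_state": {...}}, keyed by the same OBS_* constants used at add() time. A condensed sketch reusing the helper from the hunk above:

buffer = _populate_buffer_for_async_test(capacity=10)
iterator = buffer.get_iterator(batch_size=4, async_prefetch=True, queue_size=1)
batch = next(iterator)
# Batches are nested dicts keyed by the same constants used when adding.
assert batch["state"][OBS_IMAGE].shape == (4, 3, 128, 128)
assert batch["next_state"][OBS_STATE].shape == (4, 11)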

@@ -6,6 +6,7 @@ import numpy as np
 import pytest

 from lerobot.processor import TransitionKey
+from lerobot.utils.constants import OBS_STATE


 @pytest.fixture
@@ -72,7 +73,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):

     # Build EnvTransition dict
     obs = {
-        "observation.state.temperature": np.float32(25.0),
+        f"{OBS_STATE}.temperature": np.float32(25.0),
         # CHW image should be converted to HWC for rr.Image
         "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
     }
@@ -97,7 +98,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
     # - action.throttle -> Scalar
     # - action.vector_0, action.vector_1 -> Scalars
     expected_keys = {
-        "observation.state.temperature",
+        f"{OBS_STATE}.temperature",
         "observation.camera",
         "action.throttle",
         "action.vector_0",
@@ -106,7 +107,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
     assert set(_keys(calls)) == expected_keys

     # Check scalar types and values
-    temp_obj = _obj_for(calls, "observation.state.temperature")
+    temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
     assert type(temp_obj).__name__ == "DummyScalar"
     assert temp_obj.value == pytest.approx(25.0)
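
One side effect of constant-izing shows up in this last hunk: suffixed keys can no longer be written as single literals, so the tests compose them with f-strings. The pattern in isolation:

# Compose sub-keys from the constant rather than hard-coding the full string.
temperature_key = f"{OBS_STATE}.temperature"  # == "observation.state.temperature"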