diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index f041a906..9f34b227 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -41,6 +41,7 @@ from lerobot.datasets.video_utils import ( decode_video_frames_torchvision, encode_video_frames, ) +from lerobot.utils.constants import OBS_IMAGE BASE_ENCODING = OrderedDict( [ @@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 32a5e0a2..174486eb 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -21,6 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -42,7 +43,7 @@ policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 30f34e71..471cb366 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -22,6 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -48,7 +49,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 175cecf6..75d81a0f 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import init_logging Action = torch.Tensor @@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy( def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: - return hw_to_dataset_features(robot.observation_features, "observation", use_video=False) + return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) def is_image_key(k: str) -> bool: @@ -141,7 +141,7 @@ def make_lerobot_observation( lerobot_features: dict[str, dict], ) -> LeRobotObservation: """Make a lerobot observation from a raw observation.""" - return build_dataset_frame(lerobot_features, robot_obs, prefix="observation") + return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR) def prepare_raw_observation( diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index a71e978b..2bac84ae 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms +from lerobot.utils.constants import OBS_PREFIX IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -58,7 +59,7 @@ def resolve_delta_timestamps( delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == "action" and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] - if key.startswith("observation.") and cfg.observation_delta_indices is not None: + if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index cdf0b744..13555dd3 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -19,7 +19,7 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import DataProcessorPipeline -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR def create_initial_features( @@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features( # Intermediate storage for categorized and filtered features. processed_features: dict[str, dict[str, Any]] = { - "action": {}, - "observation": {}, + ACTION: {}, + OBS_STR: {}, } images_token = OBS_IMAGES.split(".")[-1] @@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features( # 3. Add the feature to the appropriate group with a clean name. name = strip_prefix(key, PREFIXES_TO_STRIP) if is_action: - processed_features["action"][name] = value + processed_features[ACTION][name] = value else: - processed_features["observation"][name] = value + processed_features[OBS_STR][name] = value # Convert the processed features into the final dataset format. dataset_features = {} - if processed_features["action"]: + if processed_features[ACTION]: dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) - if processed_features["observation"]: - dataset_features.update( - hw_to_dataset_features(processed_features["observation"], "observation", use_videos) - ) + if processed_features[OBS_STR]: + dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos)) return dataset_features diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 922fc4e3..96ae2eca 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk @@ -652,7 +653,7 @@ def hw_to_dataset_features( "names": list(joint_fts), } - if joint_fts and prefix == "observation": + if joint_fts and prefix == OBS_STR: features[f"{prefix}.state"] = { "dtype": "float32", "shape": (len(joint_fts),), @@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) - elif key == "observation.environment_state": + elif key == OBS_ENV_STATE: type = FeatureType.ENV - elif key.startswith("observation"): + elif key.startswith(OBS_STR): type = FeatureType.STATE elif key.startswith("action"): type = FeatureType.ACTION diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index f0aa0b5c..023ceea6 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -26,6 +26,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.utils.utils import get_channel_first_image_shape @@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations = {} if "pixels" in observations: if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()} else: - imgs = {"observation.image": observations["pixels"]} + imgs = {OBS_IMAGE: observations["pixels"]} for imgkey, img in imgs.items(): # TODO(aliberts, rcadene): use transforms.ToTensor()? @@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if env_state.dim() == 1: env_state = env_state.unsqueeze(0) - return_observations["observation.environment_state"] = env_state + return_observations[OBS_ENV_STATE] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing agent_pos = torch.from_numpy(observations["agent_pos"]).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) - return_observations["observation.state"] = agent_pos + return_observations[OBS_STATE] = agent_pos return return_observations diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e4ebec19..f8261bb7 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE class ACTPolicy(PreTrainedPolicy): @@ -398,10 +398,10 @@ class ACT(nn.Module): "actions must be provided when using the variational objective in training mode." ) - if "observation.images" in batch: - batch_size = batch["observation.images"][0].shape[0] + if OBS_IMAGES in batch: + batch_size = batch[OBS_IMAGES][0].shape[0] else: - batch_size = batch["observation.environment_state"].shape[0] + batch_size = batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. if self.config.use_vae and "action" in batch and self.training: @@ -410,7 +410,7 @@ class ACT(nn.Module): self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) @@ -430,7 +430,7 @@ class ACT(nn.Module): cls_joint_is_pad = torch.full( (batch_size, 2 if self.config.robot_state_feature else 1), False, - device=batch["observation.state"].device, + device=batch[OBS_STATE].device, ) key_padding_mask = torch.cat( [cls_joint_is_pad, batch["action_is_pad"]], axis=1 @@ -454,7 +454,7 @@ class ACT(nn.Module): mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device + batch[OBS_STATE].device ) # Prepare transformer encoder inputs. @@ -462,18 +462,16 @@ class ACT(nn.Module): encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) # Robot state token. if self.config.robot_state_feature: - encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE])) # Environment state token. if self.config.env_state_feature: - encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) - ) + encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE])) if self.config.image_features: # For a list of images, the H and W may vary but H*W is constant. # NOTE: If modifying this section, verify on MPS devices that # gradients remain stable (no explosions or NaNs). - for img in batch["observation.images"]: + for img in batch[OBS_IMAGES]: cam_features = self.backbone(img)["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 0bd2e282..af1327ba 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy): def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.state": deque(maxlen=self.config.n_obs_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } if self.config.image_features: - self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: @@ -234,7 +234,7 @@ class DiffusionModel(nn.Module): if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...") img_features_list = torch.cat( [ encoder(images) @@ -249,7 +249,7 @@ class DiffusionModel(nn.Module): else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...") ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). @@ -275,7 +275,7 @@ class DiffusionModel(nn.Module): "observation.environment_state": (B, n_obs_steps, environment_dim) } """ - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Encode image features and concatenate them all together along with the state vector. @@ -306,9 +306,9 @@ class DiffusionModel(nn.Module): } """ # Input validation. - assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) - assert "observation.images" in batch or "observation.environment_state" in batch - n_obs_steps = batch["observation.state"].shape[1] + assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"}) + assert OBS_IMAGES in batch or OBS_ENV_STATE in batch + n_obs_steps = batch[OBS_STATE].shape[1] horizon = batch["action"].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index c9728e41..bd5bbf7e 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0") @@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig): # raise ValueError("You must provide at least one image or the environment state among the inputs.") for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py index c0c2e481..fe986569 100644 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE def display(tensor: torch.Tensor): @@ -60,26 +61,26 @@ def main(): # Override stats dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) - dataset_meta.stats["observation.state"]["mean"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor( norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 ) - dataset_meta.stats["observation.state"]["std"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["std"] = torch.tensor( norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 ) # Create LeRobot batch from Jax batch = {} for cam_key, uint_chw_array in example["images"].items(): - batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 - batch["observation.state"] = torch.from_numpy(example["state"]) + batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch[OBS_STATE] = torch.from_numpy(example["state"]) batch["action"] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] if model_name == "pi0_aloha_towel": - del batch["observation.images.cam_low"] + del batch[f"{OBS_IMAGES}.cam_low"] elif model_name == "pi0_aloha_sim": - batch["observation.images.top"] = batch["observation.images.cam_high"] - del batch["observation.images.cam_high"] + batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"] + del batch[f"{OBS_IMAGES}.cam_high"] # Batchify for key in batch: diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py index b72bcd73..705b61ea 100644 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ b/src/lerobot/policies/pi0fast/configuration_pi0fast.py @@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0fast") @@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index fcaf02a4..a6ed79d4 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module): ) def _init_state_layers(self) -> None: - self.has_env = "observation.environment_state" in self.config.input_features - self.has_state = "observation.state" in self.config.input_features + self.has_env = OBS_ENV_STATE in self.config.input_features + self.has_state = OBS_STATE in self.config.input_features if self.has_env: - dim = self.config.input_features["observation.environment_state"].shape[0] + dim = self.config.input_features[OBS_ENV_STATE].shape[0] self.env_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), nn.Tanh(), ) if self.has_state: - dim = self.config.input_features["observation.state"].shape[0] + dim = self.config.input_features[OBS_STATE].shape[0] self.state_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), @@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module): cache = self.get_cached_image_features(obs) parts.append(self._encode_images(cache, detach)) if self.has_env: - parts.append(self.env_encoder(obs["observation.environment_state"])) + parts.append(self.env_encoder(obs[OBS_ENV_STATE])) if self.has_state: - parts.append(self.state_encoder(obs["observation.state"])) + parts.append(self.state_encoder(obs[OBS_STATE])) if parts: return torch.cat(parts, dim=-1) diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index fc53283b..9b76b803 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import OBS_IMAGE @PreTrainedConfig.register_subclass(name="reward_classifier") @@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate feature configurations.""" - has_image = any(key.startswith("observation.image") for key in self.input_features) + has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features) if not has_image: raise ValueError( "You must provide an image observation (key starting with 'observation.image') in the input features" diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index 571900c4..eedf477a 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("smolvla") @@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index f8304886..4b5e8b7b 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -38,7 +38,7 @@ from torch import Tensor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD class TDMPCPolicy(PreTrainedPolicy): @@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy): called on `env.reset()` """ self._queues = { - "observation.state": deque(maxlen=1), + OBS_STATE: deque(maxlen=1), "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: - self._queues["observation.image"] = deque(maxlen=1) + self._queues[OBS_IMAGE] = deque(maxlen=1) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=1) + self._queues[OBS_ENV_STATE] = deque(maxlen=1) # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # CEM for the next step. self._prev_mean: torch.Tensor | None = None @@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy): action = batch[ACTION] # (t, b, action_dim) reward = batch[REWARD] # (t, b) - observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: @@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) # `z_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # `z_targets` depends on the next observation. - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy): * F.mse_loss(reward_preds, reward, reduction="none") * ~batch["next.reward_is_pad"] # `reward_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy): reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # q_targets depends on the reward and the next observations. * ~batch["next.reward_is_pad"] - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * raw_v_value_loss # `v_targets` depends on the first observation and the actions, as does `v_preds`. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy): * mse * temporal_loss_coeffs # `action_preds` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ).mean() diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 34e5b1c0..91d60970 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy): batch.pop(ACTION) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original # NOTE: It's important that this happens after stacking the images into a single key. - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out if ACTION in batch: batch.pop(ACTION) @@ -340,14 +340,12 @@ class VQBeTModel(nn.Module): def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: # Input validation. - assert set(batch).issuperset({"observation.state", "observation.images"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert set(batch).issuperset({OBS_STATE, OBS_IMAGES}) + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") - ) + img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")) # Separate batch and sequence dims. img_features = einops.rearrange( img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images @@ -359,9 +357,7 @@ class VQBeTModel(nn.Module): img_features ) # (batch, obs_step, number of different cameras, projection dims) input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] - input_tokens.append( - self.state_projector(batch["observation.state"]) - ) # (batch, obs_step, projection dims) + input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims) input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) # Interleave tokens by stacking and rearranging. input_tokens = torch.stack(input_tokens, dim=2) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 440f8b1d..2e80cf4b 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,6 +23,8 @@ from typing import Any import numpy as np import torch +from lerobot.utils.constants import OBS_PREFIX + from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: raise ValueError(f"Action should be a PolicyAction type got {type(action)}") # Extract observation and complementary data keys. - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} complementary_data = _extract_complementary_data(batch) return create_transition( diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 2b9402be..48621815 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,7 +21,7 @@ import torch from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) for old_prefix, new_prefix in prefix_pairs.items(): - prefixed_old = f"observation.{old_prefix}" + prefixed_old = f"{OBS_STR}.{old_prefix}" if key.startswith(prefixed_old): suffix = key[len(prefixed_old) :] new_key = f"{new_prefix}{suffix}" @@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Exact-name rules (pixels, environment_state, agent_pos) for old, new in exact_pairs.items(): - if key == old or key == f"observation.{old}": + if key == old or key == f"{OBS_STR}.{old}": new_key = new new_features[src_ft][new_key] = feat handled = True diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index c6580189..fbf36de3 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_IMAGE from lerobot.utils.transition import Transition @@ -240,7 +241,7 @@ class ReplayBuffer: idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) # Identify image keys that need augmentation - image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] # Create batched state and next_state batch_state = {} diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index f91d077f..39313570 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,6 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -180,7 +181,7 @@ class RobotEnv(gym.Env): # Define observation spaces for images and other states. if current_observation is not None and "pixels" in current_observation: - prefix = "observation.images" + prefix = OBS_IMAGES observation_spaces = { f"{prefix}.{key}": gym.spaces.Box( low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 @@ -190,7 +191,7 @@ class RobotEnv(gym.Env): if current_observation is not None: agent_pos = current_observation["agent_pos"] - observation_spaces["observation.state"] = gym.spaces.Box( + observation_spaces[OBS_STATE] = gym.spaces.Box( low=0, high=10, shape=agent_pos.shape, @@ -612,7 +613,7 @@ def control_loop( } for key, value in transition[TransitionKey.OBSERVATION].items(): - if key == "observation.state": + if key == OBS_STATE: features[key] = { "dtype": "float32", "shape": value.squeeze(0).shape, diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 9f636715..392d6d57 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,6 +23,7 @@ from typing import Any import cv2 import numpy as np +from lerobot.utils.constants import OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -203,7 +204,7 @@ class LeKiwiClient(Robot): state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec} + obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec} # Decode images current_frames: dict[str, np.ndarray] = {} diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2033b36b..5c0d31f7 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,6 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_STATE class EpisodeSampler(torch.utils.data.Sampler): @@ -161,8 +162,8 @@ def visualize_dataset( rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) # display each dimension of observed state space (e.g. agent position in joint space) - if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): + if OBS_STATE in batch: + for dim_idx, val in enumerate(batch[OBS_STATE][i]): rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) if "next.done" in batch: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index ca900f8d..310f771a 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,6 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.utils.constants import OBS_STR from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -221,7 +222,7 @@ def rollout( stacked_observations = {} for key in all_observations[0]: stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) - ret["observation"] = stacked_observations + ret[OBS_STR] = stacked_observations if hasattr(policy, "use_original_modules"): policy.use_original_modules() @@ -459,8 +460,8 @@ def _compile_episode_data( for k in ep_dict: ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) - for key in rollout_data["observation"]: - ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + for key in rollout_data[OBS_STR]: + ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames] ep_dicts.append(ep_dict) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dd4984fa..f1d026a3 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -109,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401 so101_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import ( init_keyboard_listener, is_headless, @@ -303,7 +304,7 @@ def record_loop( obs_processed = robot_observation_processor(obs) if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) # Get action from either policy or teleop if policy is not None and preprocessor is not None and postprocessor is not None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 464969c7..33781790 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -17,19 +17,21 @@ from pathlib import Path from huggingface_hub.constants import HF_HOME -OBS_ENV_STATE = "observation.environment_state" -OBS_STATE = "observation.state" -OBS_IMAGE = "observation.image" -OBS_IMAGES = "observation.images" -OBS_LANGUAGE = "observation.language" +OBS_STR = "observation" +OBS_PREFIX = OBS_STR + "." +OBS_ENV_STATE = OBS_STR + ".environment_state" +OBS_STATE = OBS_STR + ".state" +OBS_IMAGE = OBS_STR + ".image" +OBS_IMAGES = OBS_IMAGE + "s" +OBS_LANGUAGE = OBS_STR + ".language" +OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" +OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" + ACTION = "action" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" -OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" -OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" - ROBOTS = "robots" ROBOT_TYPE = "robot_type" TELEOPERATORS = "teleoperators" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 7fc881f2..ae070b7c 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -19,6 +19,8 @@ from typing import Any import numpy as np import rerun as rr +from .constants import OBS_PREFIX, OBS_STR + def init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" @@ -63,7 +65,7 @@ def log_rerun_data( for k, v in observation.items(): if v is None: continue - key = k if str(k).startswith("observation.") else f"observation.{k}" + key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}" if _is_scalar(v): rr.log(key, rr.Scalar(float(v))) diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index b0ffa9a3..e130ae14 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -24,6 +24,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors +from lerobot.utils.constants import OBS_STR from lerobot.utils.random_utils import set_seed @@ -92,7 +93,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # for backward compatibility if k == "task": continue - if k.startswith("observation"): + if k.startswith(OBS_STR): obs[k] = batch[k] if hasattr(train_cfg.policy, "n_action_steps"): diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index f1c7636e..acf5870d 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import ( resize_robot_observation_image, ) from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # --------------------------------------------------------------------- # FPSTracker @@ -115,7 +116,7 @@ def test_timed_action_getters(): def test_timed_observation_getters(): """TimedObservation stores & returns timestamp, dict and timestep.""" ts = time.time() - obs_dict = {"observation.state": torch.ones(6)} + obs_dict = {OBS_STATE: torch.ones(6)} to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) @@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters(): # ------------------------------------------------------------------ # TimedObservation # ------------------------------------------------------------------ - obs_dict = {"observation.state": torch.arange(4).float()} + obs_dict = {OBS_STATE: torch.arange(4).float()} to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) to_bytes = pickle.dumps(to_in) # nosec @@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters(): assert to_out.get_timestep() == 7 assert to_out.must_go is True assert to_out.get_observation().keys() == obs_dict.keys() - torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) + torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE]) # --------------------------------------------------------------------- @@ -187,7 +188,7 @@ def test_observations_similar_true(): """Distance below atol → observations considered similar.""" # Create mock lerobot features for the similarity check lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], @@ -222,17 +223,17 @@ def _create_mock_robot_observation(): def _create_mock_lerobot_features(): """Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" return { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], }, - "observation.images.phone": { + f"{OBS_IMAGES}.phone": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], @@ -243,11 +244,11 @@ def _create_mock_lerobot_features(): def _create_mock_policy_image_features(): """Create mock policy image features with different resolutions.""" return { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 224, 224), # Policy expects smaller resolution ), - "observation.images.phone": PolicyFeature( + f"{OBS_IMAGES}.phone": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 160, 160), # Different resolution for second camera ), @@ -306,21 +307,21 @@ def test_prepare_raw_observation(): prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) # Check that state is properly extracted and batched - assert "observation.state" in prepared - state = prepared["observation.state"] + assert OBS_STATE in prepared + state = prepared[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.shape == (1, 4) # Batched state # Check that images are processed and resized - assert "observation.images.laptop" in prepared - assert "observation.images.phone" in prepared + assert f"{OBS_IMAGES}.laptop" in prepared + assert f"{OBS_IMAGES}.phone" in prepared - laptop_img = prepared["observation.images.laptop"] - phone_img = prepared["observation.images.phone"] + laptop_img = prepared[f"{OBS_IMAGES}.laptop"] + phone_img = prepared[f"{OBS_IMAGES}.phone"] # Check image shapes match policy requirements - assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape - assert phone_img.shape == policy_image_features["observation.images.phone"].shape + assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape + assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape # Check that images are tensors assert isinstance(laptop_img, torch.Tensor) @@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) # Check that all expected keys are present - assert "observation.state" in observation - assert "observation.images.laptop" in observation - assert "observation.images.phone" in observation + assert OBS_STATE in observation + assert f"{OBS_IMAGES}.laptop" in observation + assert f"{OBS_IMAGES}.phone" in observation # Check state processing - state = observation["observation.state"] + state = observation[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.device.type == device assert state.shape == (1, 4) # Batched # Check image processing - laptop_img = observation["observation.images.laptop"] - phone_img = observation["observation.images.phone"] + laptop_img = observation[f"{OBS_IMAGES}.laptop"] + phone_img = observation[f"{OBS_IMAGES}.phone"] # Images should have batch dimension: (B, C, H, W) assert laptop_img.shape == (1, 3, 224, 224) @@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content(): robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [100, 100, 3], "names": ["height", "width", "channels"], }, } policy_image_features = { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 50, 50), # Downsamples from 100x100 ) @@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") - processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim + processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim # Check that the center region has higher values than corners # Due to bilinear interpolation, exact values will change but pattern should remain diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index c5c52460..de441ff0 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -23,6 +23,7 @@ import pytest import torch from lerobot.configs.types import PolicyFeature +from lerobot.utils.constants import OBS_STATE from tests.utils import require_package # ----------------------------------------------------------------------------- @@ -44,7 +45,7 @@ class MockPolicy: def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Return a chunk of 20 dummy actions.""" - batch_size = len(observation["observation.state"]) + batch_size = len(observation[OBS_STATE]) return torch.zeros(batch_size, 20, 6) def __init__(self): @@ -77,7 +78,7 @@ def policy_server(): # Add mock lerobot_features that the observation similarity functions need server.lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [6], "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 8f8179c2..982f35c3 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import ( sample_images, sample_indices, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def mock_load_image_as_numpy(path, dtype, channel_first): @@ -136,21 +137,21 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { - "observation.image": [f"image_{i}.jpg" for i in range(100)], - "observation.state": np.random.rand(100, 10), + OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)], + OBS_STATE: np.random.rand(100, 10), } features = { - "observation.image": {"dtype": "image"}, - "observation.state": {"dtype": "numeric"}, + OBS_IMAGE: {"dtype": "image"}, + OBS_STATE: {"dtype": "numeric"}, } with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 - assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert OBS_IMAGE in stats and OBS_STATE in stats + assert stats[OBS_IMAGE]["count"].item() == 100 + assert stats[OBS_STATE]["count"].item() == 100 + assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1) def test_assert_type_and_shape_valid(): @@ -224,38 +225,38 @@ def test_aggregate_feature_stats(): def test_aggregate_stats(): all_stats = [ { - "observation.image": { + OBS_IMAGE: { "min": [1, 2, 3], "max": [10, 20, 30], "mean": [5.5, 10.5, 15.5], "std": [2.87, 5.87, 8.87], "count": 10, }, - "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, }, { - "observation.image": { + OBS_IMAGE: { "min": [2, 1, 0], "max": [15, 10, 5], "mean": [8.5, 5.5, 2.5], "std": [3.42, 2.42, 1.42], "count": 15, }, - "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, }, ] expected_agg_stats = { - "observation.image": { + OBS_IMAGE: { "min": [1, 1, 0], "max": [15, 20, 30], "mean": [7.3, 7.5, 7.7], "std": [3.5317, 4.8267, 8.5581], "count": 25, }, - "observation.state": { + OBS_STATE: { "min": 1, "max": 15, "mean": 7.3, @@ -283,7 +284,7 @@ def test_aggregate_stats(): for fkey, stats in ep_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) @@ -292,7 +293,7 @@ def test_aggregate_stats(): for fkey, stats in expected_agg_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index f1ffd800..c0b07ca6 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.utils.constants import OBS_IMAGES def test_default_parameters(): @@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup(): def test_non_vector_last_wins_for_images(): # Non-vector (images) with same name should be overwritten by the last image specified g1 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 480, 640), "names": ["channels", "height", "width"], } } g2 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 720, 1280), "names": ["channels", "height", "width"], @@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images(): } out = combine_feature_dicts(g1, g2) - assert out["observation.images.front"]["shape"] == (3, 720, 1280) - assert out["observation.images.front"]["dtype"] == "image" + assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280) + assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image" def test_dtype_mismatch_raises(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d1d6dbdb..1d461c8b 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,6 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) action_features = hw_to_dataset_features(robot.action_features, "action", True) - obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True) dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" dataset_create = LeRobotDataset.create( @@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name): ("frame_index", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? - ("observation.state", 1, True), + (OBS_STATE, 1, True), ("next.reward", 0, False), ("next.done", 0, False), ] @@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.""" features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], @@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): def test_update_chunk_settings_video_dataset(tmp_path): """Test update_chunk_settings with a video dataset to ensure video-specific logic works.""" features = { - "observation.images.cam": { + f"{OBS_IMAGES}.cam": { "dtype": "video", "shape": (480, 640, 3), "names": ["height", "width", "channels"], diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index 0be1b9c7..7a878223 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -19,6 +19,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.utils.constants import OBS_IMAGE from tests.utils import require_package @@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params(): config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), @@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), } @@ -83,7 +84,7 @@ def test_multiclass_classifier(): num_classes = 5 config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), @@ -95,7 +96,7 @@ def test_multiclass_classifier(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.rand((batch_size, num_classes)), } diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index b577e576..7752ad63 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -41,7 +41,7 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p # Create only one camera input which is squared to fit all current policy constraints # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared camera_features = { - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "shape": (84, 84, 3), "names": ["height", "width", "channels"], "info": None, @@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], }, - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], @@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool): preventing erroneous creation of the policy object. """ input_features = { - "observation.state": PolicyFeature( + OBS_STATE: PolicyFeature( type=FeatureType.STATE, shape=(10,), ), @@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool): """Simulates the complete state/action is constructed from more granular multiple keys, of the same type as the overall state/action""" input_features = {} - input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) output_features = {} output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index a67815ee..59ed4af6 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import ( PolicyConfig, SACConfig, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_sac_config_default_initialization(): @@ -37,11 +38,11 @@ def test_sac_config_default_initialization(): "ACTION": NormalizationMode.MIN_MAX, } assert config.dataset_stats == { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -90,11 +91,11 @@ def test_sac_config_default_initialization(): # Dataset stats defaults expected_dataset_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -191,7 +192,7 @@ def test_sac_config_custom_initialization(): def test_validate_features(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) config.validate_features() @@ -210,7 +211,7 @@ def test_validate_features_missing_observation(): def test_validate_features_missing_action(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) with pytest.raises(ValueError, match="You must provide 'action' in the output features"): diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 7891c2e5..71e45e05 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed try: @@ -85,14 +86,14 @@ def test_sac_policy_with_default_args(): def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.image": torch.randn(batch_size, 3, 84, 84), - "observation.state": torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), } @@ -126,14 +127,14 @@ def create_train_batch_with_visual_input( def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), - "observation.image": torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), } @@ -180,10 +181,10 @@ def create_default_config( action_dim += 1 config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, dataset_stats={ - "observation.state": { + OBS_STATE: { "min": [0.0] * state_dim, "max": [1.0] * state_dim, }, @@ -205,8 +206,8 @@ def create_config_with_visual_input( continuous_action_dim=continuous_action_dim, has_discrete_action=has_discrete_action, ) - config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) - config.dataset_stats["observation.image"] = { + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { "mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1), } diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index 00a4dbb9..134cff68 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -342,7 +342,7 @@ def test_act_processor_batch_consistency(): batch = transition_to_batch(transition) processed = preprocessor(batch) - assert processed["observation.state"].shape[0] == 1 # Batched + assert processed[OBS_STATE].shape[0] == 1 # Batched # Test already batched data observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 631ad789..8bf24db0 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,14 +2,15 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE def _dummy_batch(): """Create a dummy batch using the new format with observation.* and next.* keys.""" return { - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.image.right": torch.randn(1, 3, 128, 128), - "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), + OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), "action": torch.tensor([[0.5]]), "next.reward": 1.0, "next.done": False, @@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip(): batch_out = proc(batch_in) # Check that all observation.* keys are preserved - original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} - reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)} assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) # Check tensor values - assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) - assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) - assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"]) + assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE]) # Check other fields assert torch.allclose(batch_out["action"], batch_in["action"]) @@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): """Test that batch_to_transition correctly groups observation.* keys.""" batch = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, @@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping(): # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) - assert "observation.image.top" in transition[TransitionKey.OBSERVATION] - assert "observation.image.left" in transition[TransitionKey.OBSERVATION] - assert "observation.state" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION] + assert OBS_STATE in transition[TransitionKey.OBSERVATION] # Check values are preserved assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"] ) assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"] ) - assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4] # Check other fields assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) @@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], } transition = { @@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening(): batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch - assert "observation.image.top" in batch - assert "observation.image.left" in batch - assert "observation.state" in batch + assert f"{OBS_IMAGE}.top" in batch + assert f"{OBS_IMAGE}.left" in batch + assert OBS_STATE in batch # Check values are preserved - assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) - assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) - assert batch["observation.state"] == [1, 2, 3, 4] + assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"]) + assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"]) + assert batch[OBS_STATE] == [1, 2, 3, 4] # Check other fields are mapped to next.* format assert batch["action"] == "action_data" @@ -153,12 +154,12 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} + batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])} transition = batch_to_transition(batch) # Check observation - assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"} assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults @@ -170,7 +171,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch[OBS_STATE] == "minimal_state" assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -205,9 +206,9 @@ def test_empty_batch(): def test_complex_nested_observation(): """Test with complex nested observation data.""" batch = { - "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, - "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, - "observation.state": torch.randn(7), + f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + OBS_STATE: torch.randn(7), "action": torch.randn(8), "next.reward": 3.14, "next.done": False, @@ -219,20 +220,20 @@ def test_complex_nested_observation(): reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved - original_obs_keys = {k for k in batch if k.startswith("observation.")} - reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} + original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)} assert original_obs_keys == reconstructed_obs_keys # Check tensor values - assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE]) # Check nested dict with tensors assert torch.allclose( - batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"] ) assert torch.allclose( - batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"] ) # Check action tensor @@ -264,7 +265,7 @@ def test_custom_converter(): processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) batch = { - "observation.state": torch.randn(1, 4), + OBS_STATE: torch.randn(1, 4), "action": torch.randn(1, 2), "next.reward": 1.0, "next.done": False, @@ -274,5 +275,5 @@ def test_custom_converter(): # Check the reward was doubled by our custom converter assert result["next.reward"] == 2.0 - assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index fc91951d..b03d4921 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,6 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) +from lerobot.utils.constants import OBS_STATE, OBS_STR # Tests for the unified to_tensor function @@ -118,16 +119,16 @@ def test_to_tensor_dictionaries(): # Nested dictionary nested = { "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, - "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10}, } result = to_tensor(nested) assert isinstance(result, dict) assert isinstance(result["action"], dict) - assert isinstance(result["observation"], dict) + assert isinstance(result[OBS_STR], dict) assert isinstance(result["action"]["mean"], torch.Tensor) - assert isinstance(result["observation"]["mean"], torch.Tensor) + assert isinstance(result[OBS_STR]["mean"], torch.Tensor) assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) - assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6])) def test_to_tensor_none_filtering(): @@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields(): # Create batch with index and task_index fields batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "next.reward": 1.5, "next.done": False, @@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields(): # Create transition with index and task_index in complementary_data transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), reward=1.5, done=False, @@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields(): # Batch without index/task_index batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "task": ["pick_cube"], } @@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields(): # Transition without index/task_index transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data={"task": ["navigate"]}, ) diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 10ee313d..36081e02 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_basic_functionality(): @@ -28,7 +29,7 @@ def test_basic_functionality(): processor = DeviceProcessorStep(device="cpu") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -41,8 +42,8 @@ def test_basic_functionality(): result = processor(transition) # Check that all tensors are on CPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu" assert result[TransitionKey.ACTION].device.type == "cpu" assert result[TransitionKey.REWARD].device.type == "cpu" assert result[TransitionKey.DONE].device.type == "cpu" @@ -55,7 +56,7 @@ def test_cuda_functionality(): processor = DeviceProcessorStep(device="cuda") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -68,8 +69,8 @@ def test_cuda_functionality(): result = processor(transition) # Check that all tensors are on CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -81,14 +82,14 @@ def test_specific_cuda_device(): """Test device processor with specific CUDA device.""" processor = DeviceProcessorStep(device="cuda:0") - observation = {"observation.state": torch.randn(10)} + observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0 assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.index == 0 @@ -98,7 +99,7 @@ def test_non_tensor_values(): processor = DeviceProcessorStep(device="cpu") observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "observation.metadata": {"key": "value"}, # Non-tensor data "observation.list": [1, 2, 3], # Non-tensor data } @@ -110,7 +111,7 @@ def test_non_tensor_values(): result = processor(transition) # Check tensors are processed - assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor) assert isinstance(result[TransitionKey.ACTION], torch.Tensor) # Check non-tensor values are preserved @@ -130,9 +131,9 @@ def test_none_values(): assert result[TransitionKey.ACTION].device.type == "cpu" # Test with None action - transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result[TransitionKey.ACTION] is None @@ -271,9 +272,7 @@ def test_features(): processor = DeviceProcessorStep(device="cpu") features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } @@ -376,7 +375,7 @@ def test_reward_done_truncated_types(): # Test with scalar values (not tensors) transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=1.0, # float done=False, # bool @@ -392,7 +391,7 @@ def test_reward_done_truncated_types(): # Test with tensor values transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=torch.tensor(1.0), done=torch.tensor(False), @@ -422,7 +421,7 @@ def test_complementary_data_preserved(): } transition = create_transition( - observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data ) result = processor(transition) @@ -491,13 +490,13 @@ def test_float_dtype_bfloat16(): """Test conversion to bfloat16.""" processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") - observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)} action = torch.randn(3, dtype=torch.float64) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 assert result[TransitionKey.ACTION].dtype == torch.bfloat16 @@ -505,13 +504,13 @@ def test_float_dtype_float64(): """Test conversion to float64.""" processor = DeviceProcessorStep(device="cpu", float_dtype="float64") - observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)} action = torch.randn(3, dtype=torch.float32) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64 assert result[TransitionKey.ACTION].dtype == torch.float64 @@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors(): processor = DeviceProcessorStep(device="cpu", float_dtype="float32") observation = { - "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert - "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + OBS_STATE: torch.randn(10, dtype=torch.float64), # Should convert "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert } @@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors(): result = processor(transition) # Check conversions - assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted @@ -612,7 +611,7 @@ def test_complementary_data_index_fields(): "episode_id": 123, # Non-tensor field } transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data=complementary_data, ) @@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda(): processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") # Create full transition with mixed CPU tensors - observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)} action = torch.randn(1, 4, dtype=torch.float32) reward = torch.tensor(1.5, dtype=torch.float32) done = torch.tensor(False) @@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda(): result = processor(transition) # Check all components moved to CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda(): assert processed_comp_data["task_index"].device.type == "cuda" # Check float conversion happened for float tensors - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 assert result[TransitionKey.ACTION].dtype == torch.float16 assert result[TransitionKey.REWARD].dtype == torch.float16 @@ -782,7 +781,7 @@ def test_complementary_data_empty(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data={}, ) @@ -797,7 +796,7 @@ def test_complementary_data_none(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data=None, ) @@ -814,8 +813,8 @@ def test_preserves_gpu_placement(): # Create tensors already on GPU observation = { - "observation.state": torch.randn(10).cuda(), # Already on GPU - "observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU + OBS_STATE: torch.randn(10).cuda(), # Already on GPU + OBS_IMAGE: torch.randn(3, 224, 224).cuda(), # Already on GPU } action = torch.randn(5).cuda() # Already on GPU @@ -823,14 +822,12 @@ def test_preserves_gpu_placement(): result = processor(transition) # Check that tensors remain on their original GPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Verify no unnecessary copies were made (same data pointer) - assert torch.equal( - result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"] - ) + assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -842,8 +839,8 @@ def test_multi_gpu_preservation(): # Create tensors on cuda:1 (simulating Accelerate placement) cuda1_device = torch.device("cuda:1") observation = { - "observation.state": torch.randn(10).to(cuda1_device), - "observation.image": torch.randn(3, 224, 224).to(cuda1_device), + OBS_STATE: torch.randn(10).to(cuda1_device), + OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device), } action = torch.randn(5).to(cuda1_device) @@ -851,20 +848,20 @@ def test_multi_gpu_preservation(): result = processor_gpu(transition) # Check that tensors remain on cuda:1 (not moved to cuda:0) - assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device - assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device assert result[TransitionKey.ACTION].device == cuda1_device # Test 2: GPU-to-CPU should move to CPU (not preserve GPU) processor_cpu = DeviceProcessorStep(device="cpu") transition_gpu = create_transition( - observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() + observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda() ) result_cpu = processor_cpu(transition_gpu) # Check that tensors are moved to CPU - assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result_cpu[TransitionKey.ACTION].device.type == "cpu" @@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario(): # Simulate data already placed by Accelerate device = torch.device(f"cuda:{gpu_id}") - observation = {"observation.state": torch.randn(1, 10).to(device)} + observation = {OBS_STATE: torch.randn(1, 10).to(device)} action = torch.randn(1, 5).to(device) transition = create_transition(observation=observation, action=action) result = processor(transition) # Verify data stays on the GPU where Accelerate placed it - assert result[TransitionKey.OBSERVATION]["observation.state"].device == device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device assert result[TransitionKey.ACTION].device == device @@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data(): } transition = create_transition( - observation={"observation.state": torch.randn(5, dtype=torch.float64)}, + observation={OBS_STATE: torch.randn(5, dtype=torch.float64)}, action=torch.randn(3, dtype=torch.float64), complementary_data=complementary_data, ) @@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data(): result = processor(transition) # Check that all tensors are on MPS device - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps" assert result[TransitionKey.ACTION].device.type == "mps" processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] @@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data(): assert processed_comp_data["float32_tensor"].device.type == "mps" # Check dtype conversions - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py index 6bed8289..b46cc6bd 100644 --- a/tests/processor/test_migration_detection.py +++ b/tests/processor/test_migration_detection.py @@ -25,6 +25,7 @@ from pathlib import Path import pytest from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError +from lerobot.utils.constants import OBS_STATE def test_is_processor_config_valid_configs(): @@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only(): # Create a model config (like old LeRobot format) model_config = { "type": "act", - "input_features": {"observation.state": {"shape": [7]}}, + "input_features": {OBS_STATE: {"shape": [7]}}, "output_features": {"action": {"shape": [7]}}, "hidden_dim": 256, "n_obs_steps": 1, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5d779191..616f33db 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -29,22 +29,23 @@ from lerobot.processor import ( hotswap_stats, ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.utils import auto_select_torch_device def test_numpy_conversion(): stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), } } tensor_stats = to_tensor(stats) - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) - assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert torch.allclose(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.2, 0.2, 0.2])) def test_tensor_conversion(): @@ -75,15 +76,15 @@ def test_scalar_conversion(): def test_list_conversion(): stats = { - "observation.state": { + OBS_STATE: { "min": [0.0, -1.0, -2.0], "max": [1.0, 1.0, 2.0], } } tensor_stats = to_tensor(stats) - assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) - assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["max"], torch.tensor([1.0, 1.0, 2.0])) def test_unsupported_type(): @@ -99,8 +100,8 @@ def test_unsupported_type(): # Helper functions to create feature maps and norm maps def _create_observation_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } @@ -115,11 +116,11 @@ def _create_observation_norm_map(): @pytest.fixture def observation_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -136,8 +137,8 @@ def observation_normalizer(observation_stats): def test_mean_std_normalization(observation_normalizer): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -146,12 +147,12 @@ def test_mean_std_normalization(observation_normalizer): # Check mean/std normalization expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(normalized_obs["observation.image"], expected_image) + assert torch.allclose(normalized_obs[OBS_IMAGE], expected_image) def test_min_max_normalization(observation_normalizer): observation = { - "observation.state": torch.tensor([0.5, 0.0]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -162,7 +163,7 @@ def test_min_max_normalization(observation_normalizer): # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 expected_state = torch.tensor([0.0, 0.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6) def test_selective_normalization(observation_stats): @@ -172,12 +173,12 @@ def test_selective_normalization(observation_stats): features=features, norm_map=norm_map, stats=observation_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -185,9 +186,9 @@ def test_selective_normalization(observation_stats): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Only image should be normalized - assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + assert torch.allclose(normalized_obs[OBS_IMAGE], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) # State should remain unchanged - assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + assert torch.allclose(normalized_obs[OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -196,26 +197,26 @@ def test_device_compatibility(observation_stats): norm_map = _create_observation_norm_map() normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]).cuda(), } transition = create_transition(observation=observation) normalized_transition = normalizer(transition) normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - assert normalized_obs["observation.image"].device.type == "cuda" + assert normalized_obs[OBS_IMAGE].device.type == "cuda" def test_from_lerobot_dataset(): # Mock dataset mock_dataset = Mock() mock_dataset.meta.stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0], "std": [1.0]}, } features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (1,)), } norm_map = { @@ -226,7 +227,7 @@ def test_from_lerobot_dataset(): normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats - assert "observation.image" in normalizer._tensor_stats + assert OBS_IMAGE in normalizer._tensor_stats assert "action" in normalizer._tensor_stats @@ -242,13 +243,13 @@ def test_state_dict_save_load(observation_normalizer): new_normalizer.load_state_dict(state_dict) # Test that it works the same - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} transition = create_transition(observation=observation) result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(result1["observation.image"], result2["observation.image"]) + assert torch.allclose(result1[OBS_IMAGE], result2[OBS_IMAGE]) # Fixtures for ActionUnnormalizer tests @@ -375,11 +376,11 @@ def test_action_from_lerobot_dataset(): @pytest.fixture def full_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -392,8 +393,8 @@ def full_stats(): def _create_full_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } @@ -415,8 +416,8 @@ def normalizer_processor(full_stats): def test_combined_normalization(normalizer_processor): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -434,7 +435,7 @@ def test_combined_normalization(normalizer_processor): # Check normalized observations processed_obs = processed_transition[TransitionKey.OBSERVATION] expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(processed_obs["observation.image"], expected_image) + assert torch.allclose(processed_obs[OBS_IMAGE], expected_image) # Check normalized action processed_action = processed_transition[TransitionKey.ACTION] @@ -455,11 +456,11 @@ def test_processor_from_lerobot_dataset(full_stats): norm_map = _create_full_norm_map() processor = NormalizerProcessorStep.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} + mock_dataset, features, norm_map, normalize_observation_keys={OBS_IMAGE} ) - assert processor.normalize_observation_keys == {"observation.image"} - assert "observation.image" in processor._tensor_stats + assert processor.normalize_observation_keys == {OBS_IMAGE} + assert OBS_IMAGE in processor._tensor_stats assert "action" in processor._tensor_stats @@ -470,17 +471,17 @@ def test_get_config(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_observation_keys": ["observation.image"], + "normalize_observation_keys": [OBS_IMAGE], "eps": 1e-6, "features": { - "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, - "observation.state": {"type": "STATE", "shape": (2,)}, + OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)}, + OBS_STATE: {"type": "STATE", "shape": (2,)}, "action": {"type": "ACTION", "shape": (2,)}, }, "norm_map": { @@ -499,8 +500,8 @@ def test_integration_with_robot_processor(normalizer_processor): ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -522,8 +523,8 @@ def test_integration_with_robot_processor(normalizer_processor): # Edge case tests def test_empty_observation(): - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -534,37 +535,35 @@ def test_empty_observation(): def test_empty_stats(): - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) - observation = {"observation.image": torch.tensor([0.5])} + observation = {OBS_IMAGE: torch.tensor([0.5])} transition = create_transition(observation=observation) result = normalizer(transition) # Should return observation unchanged since no stats are available - assert torch.allclose( - result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] - ) + assert torch.allclose(result[TransitionKey.OBSERVATION][OBS_IMAGE], observation[OBS_IMAGE]) def test_partial_stats(): """If statistics are incomplete, the value should pass through unchanged.""" - stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5]}} # Missing std / (min,max) + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} transition = create_transition(observation=observation) processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed["observation.image"], observation["observation.image"]) + assert torch.allclose(processed[OBS_IMAGE], observation[OBS_IMAGE]) def test_missing_action_stats_no_error(): mock_dataset = Mock() - mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + mock_dataset.meta.stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) @@ -580,7 +579,7 @@ def test_serialization_roundtrip(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) @@ -598,8 +597,8 @@ def test_serialization_roundtrip(full_stats): # Test that both processors work the same way observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -617,8 +616,8 @@ def test_serialization_roundtrip(full_stats): # Compare results assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) @@ -644,23 +643,23 @@ def test_serialization_roundtrip(full_stats): def test_identity_normalization_observations(): """Test that IDENTITY mode skips normalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -668,11 +667,11 @@ def test_identity_normalization_observations(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_normalization_actions(): @@ -695,23 +694,23 @@ def test_identity_normalization_actions(): def test_identity_unnormalization_observations(): """Test that IDENTITY mode skips unnormalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] } transition = create_transition(observation=observation) @@ -719,13 +718,13 @@ def test_identity_unnormalization_observations(): unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(unnormalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be unnormalized (MIN_MAX) # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 expected_state = torch.tensor([0.0, -1.0]) - assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + assert torch.allclose(unnormalized_obs[OBS_STATE], expected_state) def test_identity_unnormalization_actions(): @@ -748,7 +747,7 @@ def test_identity_unnormalization_actions(): def test_identity_with_missing_stats(): """Test that IDENTITY mode works even when stats are missing.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -760,7 +759,7 @@ def test_identity_with_missing_stats(): normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -769,13 +768,13 @@ def test_identity_with_missing_stats(): unnormalized_transition = unnormalizer(transition) assert torch.allclose( - normalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + normalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) assert torch.allclose( - unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + unnormalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) @@ -783,8 +782,8 @@ def test_identity_with_missing_stats(): def test_identity_mixed_with_other_modes(): """Test IDENTITY mode mixed with other normalization modes.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -793,16 +792,16 @@ def test_identity_mixed_with_other_modes(): FeatureType.ACTION: NormalizationMode.MIN_MAX, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } action = torch.tensor([0.5, 0.0]) transition = create_transition(observation=observation, action=action) @@ -812,11 +811,11 @@ def test_identity_mixed_with_other_modes(): normalized_action = normalized_transition[TransitionKey.ACTION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) # Action should be normalized (MIN_MAX) to [-1, 1] # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 @@ -828,23 +827,23 @@ def test_identity_mixed_with_other_modes(): def test_identity_defaults_when_not_in_norm_map(): """Test that IDENTITY is used as default when feature type not in norm_map.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.STATE: NormalizationMode.MEAN_STD, # VISUAL not specified, should default to IDENTITY } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -852,17 +851,17 @@ def test_identity_defaults_when_not_in_norm_map(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (defaults to IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (explicitly MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_roundtrip(): """Test that IDENTITY normalization and unnormalization are true inverses.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -870,14 +869,14 @@ def test_identity_roundtrip(): FeatureType.ACTION: NormalizationMode.IDENTITY, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} original_action = torch.tensor([0.5, -0.2]) original_transition = create_transition(observation=original_observation, action=original_action) @@ -886,16 +885,14 @@ def test_identity_roundtrip(): roundtrip = unnormalizer(normalized) # Should be identical to original - assert torch.allclose( - roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] - ) + assert torch.allclose(roundtrip[TransitionKey.OBSERVATION][OBS_IMAGE], original_observation[OBS_IMAGE]) assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) def test_identity_config_serialization(): """Test that IDENTITY mode is properly saved and loaded in config.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -903,7 +900,7 @@ def test_identity_config_serialization(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } @@ -925,7 +922,7 @@ def test_identity_config_serialization(): ) # Test that both work the same way - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -934,15 +931,15 @@ def test_identity_config_serialization(): # Results should be identical assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # def test_unsupported_normalization_mode_error(): # """Test that unsupported normalization modes raise appropriate errors.""" -# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} +# features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (2,))} # # Create an invalid norm_map (this would never happen in practice, but tests error handling) # from enum import Enum @@ -953,14 +950,14 @@ def test_identity_config_serialization(): # # We can't actually pass an invalid enum to the processor due to type checking, # # but we can test the error by manipulating the norm_map after creation # norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} -# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} +# stats = {OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} # normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # # Manually inject an invalid mode to test error handling # normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" -# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# observation = {OBS_STATE: torch.tensor([1.0, -0.5])} # transition = create_transition(observation=observation) # with pytest.raises(ValueError, match="Unsupported normalization mode"): @@ -971,19 +968,19 @@ def test_hotswap_stats_basic_functionality(): """Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps.""" # Create initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create new stats for hotswapping new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create features and norm_map features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1021,15 +1018,15 @@ def test_hotswap_stats_basic_functionality(): def test_hotswap_stats_deep_copy(): """Test that hotswap_stats creates a deep copy and doesn't modify the original processor.""" initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1060,15 +1057,15 @@ def test_hotswap_stats_deep_copy(): def test_hotswap_stats_only_affects_normalizer_steps(): """Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1093,13 +1090,13 @@ def test_hotswap_stats_only_affects_normalizer_steps(): def test_hotswap_stats_empty_stats(): """Test hotswap_stats with empty stats dictionary.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } empty_stats = {} features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1117,7 +1114,7 @@ def test_hotswap_stats_empty_stats(): def test_hotswap_stats_no_normalizer_steps(): """Test hotswap_stats with a processor that has no normalizer/unnormalizer steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # Create processor with only identity steps @@ -1139,18 +1136,18 @@ def test_hotswap_stats_no_normalizer_steps(): def test_hotswap_stats_preserves_other_attributes(): """Test that hotswap_stats preserves other processor attributes like features and norm_map.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalize_observation_keys = {"observation.image"} + normalize_observation_keys = {OBS_IMAGE} eps = 1e-6 normalizer = NormalizerProcessorStep( @@ -1179,17 +1176,17 @@ def test_hotswap_stats_preserves_other_attributes(): def test_hotswap_stats_multiple_normalizer_types(): """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } norm_map = { @@ -1224,12 +1221,12 @@ def test_hotswap_stats_multiple_normalizer_types(): def test_hotswap_stats_with_different_data_types(): """Test hotswap_stats with various data types in stats.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # New stats with different data types (int, float, list, tuple) new_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.3, 0.4, 0.5], # list "std": (0.1, 0.2, 0.3), # tuple "min": 0, # int @@ -1242,7 +1239,7 @@ def test_hotswap_stats_with_different_data_types(): } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1261,43 +1258,43 @@ def test_hotswap_stats_with_different_data_types(): # Check that tensor conversion worked correctly tensor_stats = new_processor.steps[0]._tensor_stats - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor) assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) assert isinstance(tensor_stats["action"]["std"], torch.Tensor) # Check values - torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5])) - torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3])) - torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0)) - torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2, 0.3])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["min"], torch.tensor(0.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["max"], torch.tensor(1.0)) def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data observation = { - "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), + OBS_IMAGE: torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), } action = torch.tensor([0.5, -0.5]) transition = create_transition(observation=observation, action=action) # Initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # New stats new_stats = { - "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1322,8 +1319,8 @@ def test_hotswap_stats_functional_test(): # Results should be different since normalization changed assert not torch.allclose( - original_result["observation"]["observation.image"], - new_result["observation"]["observation.image"], + original_result[OBS_STR][OBS_IMAGE], + new_result[OBS_STR][OBS_IMAGE], rtol=1e-3, atol=1e-3, ) @@ -1331,60 +1328,54 @@ def test_hotswap_stats_functional_test(): # Verify that the new processor is actually using the new stats by checking internal state assert new_processor.steps[0].stats == new_stats - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2]) - ) - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2]) - ) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) # Test that normalization actually happens (output should not equal input) - assert not torch.allclose( - new_result["observation"]["observation.image"], observation["observation.image"] - ) + assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE]) assert not torch.allclose(new_result["action"], action) def test_zero_std_uses_eps(): """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + stats = {OBS_STATE: {"mean": np.array([0.5]), "std": np.array([0.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([0.5])} # equals mean + observation = {OBS_STATE: torch.tensor([0.5])} # equals mean out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([0.0])) def test_min_equals_max_maps_to_minus_one(): """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} - stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + stats = {OBS_STATE: {"min": np.array([2.0]), "max": np.array([2.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([2.0])} + observation = {OBS_STATE: torch.tensor([2.0])} out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([-1.0])) def test_action_normalized_despite_normalize_observation_keys(): """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { - "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep( - features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE} ) transition = create_transition( - observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + observation={OBS_STATE: torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) ) out = normalizer(transition) # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 @@ -1421,12 +1412,12 @@ def test_unnormalize_observations_mean_std_and_min_max(): def test_unknown_observation_keys_ignored(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + obs = {OBS_STATE: torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} tr = create_transition(observation=obs) out = normalizer(tr) @@ -1447,13 +1438,13 @@ def test_batched_action_normalization(): def test_complementary_data_preservation(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) comp = {"existing": 123} - tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + tr = create_transition(observation={OBS_STATE: torch.tensor([1.0])}, complementary_data=comp) out = normalizer(tr) new_comp = out[TransitionKey.COMPLEMENTARY_DATA] assert new_comp["existing"] == 123 @@ -1461,36 +1452,34 @@ def test_complementary_data_preservation(): def test_roundtrip_normalize_unnormalize_non_identity(): features = { - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} stats = { - "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # Add a time dimension in action for broadcasting check (B,T,D) - obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + obs = {OBS_STATE: torch.tensor([[3.0, 3.0], [1.0, -1.0]])} act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] tr = create_transition(observation=obs, action=act) out = unnormalizer(normalizer(tr)) - assert torch.allclose( - out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 - ) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], obs[OBS_STATE], atol=1e-5) assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) def test_dtype_adaptation_bfloat16_input_float32_normalizer(): """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (5,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} stats = { - "observation.state": { + OBS_STATE: { "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), } @@ -1503,11 +1492,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify initial configuration assert normalizer.dtype == torch.float32 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.float32 # Create bfloat16 input tensor - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} transition = create_transition(observation=observation) # Process the transition @@ -1516,11 +1505,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify that: # 1. Stats were automatically adapted to bfloat16 assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 # 2. Output is in bfloat16 - output_tensor = result[TransitionKey.OBSERVATION]["observation.state"] + output_tensor = result[TransitionKey.OBSERVATION][OBS_STATE] assert output_tensor.dtype == torch.bfloat16 # 3. Normalization was applied correctly (mean should be close to original - mean) / std @@ -1540,18 +1529,18 @@ def test_stats_override_preservation_in_load_state_dict(): """ # Create original stats original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create override stats (what user wants to use) override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1611,12 +1600,12 @@ def test_stats_without_override_loads_normally(): load_state_dict works as before. """ original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1651,12 +1640,12 @@ def test_stats_without_override_loads_normally(): def test_stats_explicit_provided_flag_detection(): """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} # Test 1: Explicitly provided stats (non-empty dict) - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) assert normalizer1._stats_explicitly_provided is True @@ -1684,7 +1673,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Create test data features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1693,12 +1682,12 @@ def test_pipeline_from_pretrained_with_stats_overrides(): } original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } @@ -1740,7 +1729,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Test that the override stats are actually used in processing observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1770,9 +1759,9 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" from lerobot.processor import DeviceProcessorStep - features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (3,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32) device_processor = DeviceProcessorStep(device=str(auto_select_torch_device()), float_dtype="bfloat16") @@ -1784,18 +1773,18 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): assert normalizer.dtype == torch.float32 # Create CPU input - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} transition = create_transition(observation=observation) # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA processed_1 = device_processor(transition) - intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"] + intermediate_tensor = processed_1[TransitionKey.OBSERVATION][OBS_STATE] assert intermediate_tensor.dtype == torch.bfloat16 assert intermediate_tensor.device.type == str(auto_select_torch_device()) # Step 2: NormalizerProcessor receives bfloat16 input and adapts final_result = normalizer(processed_1) - final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"] + final_tensor = final_result[TransitionKey.OBSERVATION][OBS_STATE] # Verify final output is bfloat16 (automatic adaptation worked) assert final_tensor.dtype == torch.bfloat16 @@ -1803,7 +1792,7 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): # Verify normalizer adapted its internal state assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 assert stat_tensor.device.type == str(auto_select_torch_device()) @@ -1821,8 +1810,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Create normalizer with stats features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -1831,11 +1820,11 @@ def test_stats_reconstruction_after_load_state_dict(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -1861,15 +1850,15 @@ def test_stats_reconstruction_after_load_state_dict(): assert new_normalizer.stats != {} # Check that all expected keys are present - assert "observation.image" in new_normalizer.stats - assert "observation.state" in new_normalizer.stats + assert OBS_IMAGE in new_normalizer.stats + assert OBS_STATE in new_normalizer.stats assert "action" in new_normalizer.stats # Check that values are correct (converted back from tensors) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5]) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) @@ -1885,8 +1874,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 2: hotswap_stats should work new_stats = { - "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, - "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, + OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, + OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, } @@ -1900,8 +1889,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 3: The normalizer should work functionally the same as the original observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1911,11 +1900,11 @@ def test_stats_reconstruction_after_load_state_dict(): # Results should be identical (within floating point precision) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.image"], - new_result[TransitionKey.OBSERVATION]["observation.image"], + original_result[TransitionKey.OBSERVATION][OBS_IMAGE], + new_result[TransitionKey.OBSERVATION][OBS_IMAGE], ) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.state"], - new_result[TransitionKey.OBSERVATION]["observation.state"], + original_result[TransitionKey.OBSERVATION][OBS_STATE], + new_result[TransitionKey.OBSERVATION][OBS_STATE], ) torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION]) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 6abc9ede..11b58a66 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -39,8 +39,8 @@ def test_process_single_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that the image was processed correctly - assert "observation.image" in processed_obs - processed_img = processed_obs["observation.image"] + assert OBS_IMAGE in processed_obs + processed_img = processed_obs[OBS_IMAGE] # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width assert processed_img.shape == (1, 3, 64, 64) @@ -66,12 +66,12 @@ def test_process_image_dict(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both images were processed - assert "observation.images.camera1" in processed_obs - assert "observation.images.camera2" in processed_obs + assert f"{OBS_IMAGES}.camera1" in processed_obs + assert f"{OBS_IMAGES}.camera2" in processed_obs # Check shapes - assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) - assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32) + assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48) def test_process_batched_image(): @@ -88,7 +88,7 @@ def test_process_batched_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimension is preserved - assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64) def test_invalid_image_format(): @@ -173,10 +173,10 @@ def test_process_environment_state(): processed_obs = result[TransitionKey.OBSERVATION] # Check that environment_state was renamed and processed - assert "observation.environment_state" in processed_obs + assert OBS_ENV_STATE in processed_obs assert "environment_state" not in processed_obs - processed_state = processed_obs["observation.environment_state"] + processed_state = processed_obs[OBS_ENV_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) @@ -194,10 +194,10 @@ def test_process_agent_pos(): processed_obs = result[TransitionKey.OBSERVATION] # Check that agent_pos was renamed and processed - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs - processed_state = processed_obs["observation.state"] + processed_state = processed_obs[OBS_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) @@ -217,8 +217,8 @@ def test_process_batched_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimensions are preserved - assert processed_obs["observation.environment_state"].shape == (2, 2) - assert processed_obs["observation.state"].shape == (2, 2) + assert processed_obs[OBS_ENV_STATE].shape == (2, 2) + assert processed_obs[OBS_STATE].shape == (2, 2) def test_process_both_states(): @@ -235,8 +235,8 @@ def test_process_both_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "environment_state" not in processed_obs @@ -281,12 +281,12 @@ def test_complete_observation_processing(): processed_obs = result[TransitionKey.OBSERVATION] # Check that image was processed - assert "observation.image" in processed_obs - assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32) # Check that states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "pixels" not in processed_obs @@ -308,7 +308,7 @@ def test_image_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.image" in processed_obs + assert OBS_IMAGE in processed_obs assert len(processed_obs) == 1 @@ -323,7 +323,7 @@ def test_state_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs @@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): proc = VanillaObservationProcessorStep() features = { PipelineFeatureType.OBSERVATION: { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), }, } @@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): assert ( OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] - == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] + == features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] ) assert ( OBS_STATE in out[PipelineFeatureType.OBSERVATION] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 0d17fed0..6d056e4d 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,6 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -255,7 +256,7 @@ def test_step_through_with_dict(): pipeline = DataProcessorPipeline([step1, step2]) batch = { - "observation.image": None, + OBS_IMAGE: None, "action": None, "next.reward": 0.0, "next.done": False, @@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions(): # Verify it uses default converters by checking with standard batch format batch = { - "observation.image": torch.randn(1, 3, 32, 32), + OBS_IMAGE: torch.randn(1, 3, 32, 32), "action": torch.randn(1, 7), "next.reward": torch.tensor([1.0]), "next.done": torch.tensor([False]), @@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict - assert "observation.image" in result + assert OBS_IMAGE in result class NonCompliantStep: @@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: # State features (mix EE and a joint state) - features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float - features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float if self.add_front_image: - features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape + features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape return features @@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only(): ) # Expect only "action" with joint names - assert "action" in out and "observation.state" not in out + assert "action" in out and OBS_STATE not in out assert out["action"]["dtype"] == "float32" assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} assert out["action"]["shape"] == (len(out["action"]["names"]),) @@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos(): pipeline=rp, initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, use_videos=True, - patterns=["action.ee", "observation.state"], + patterns=["action.ee", OBS_STATE], ) # Action should pack only EE names @@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos(): assert out["action"]["dtype"] == "float32" # Observation state should pack both ee.x and j1.pos as a vector - assert "observation.state" in out - assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} - assert out["observation.state"]["dtype"] == "float32" + assert OBS_STATE in out + assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"} + assert out[OBS_STATE]["dtype"] == "float32" # Cameras from initial_features appear as videos for cam in ("front", "side"): - key = f"observation.images.{cam}" + key = f"{OBS_IMAGES}.{cam}" assert key in out assert out[key]["dtype"] == "video" assert out[key]["shape"] == initial[cam] @@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false(): patterns=None, ) - key = "observation.images.back" - key_front = "observation.images.front" + key = f"{OBS_IMAGES}.back" + key_front = f"{OBS_IMAGES}.front" assert key not in out assert key_front not in out @@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true(): patterns=None, ) - key = "observation.images.front" - key_back = "observation.images.back" + key = f"{OBS_IMAGES}.front" + key_back = f"{OBS_IMAGES}.back" assert key in out assert key_back in out assert out[key]["dtype"] == "video" @@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image(): pipeline=rp, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, use_videos=True, - patterns=["observation.images.front"], + patterns=[f"{OBS_IMAGES}.front"], ) - key = "observation.images.front" + key = f"{OBS_IMAGES}.front" assert key in out assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 5f2b4857..c6aa303f 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -28,6 +28,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.rename_processor import rename_stats +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -121,13 +122,13 @@ def test_overlapping_rename(): def test_partial_rename(): """Test renaming only some keys.""" rename_map = { - "observation.state": "observation.proprio_state", - "pixels": "observation.image", + OBS_STATE: "observation.proprio_state", + "pixels": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), "reward": 1.0, "info": {"episode": 1}, @@ -139,8 +140,8 @@ def test_partial_rename(): # Check renamed keys assert "observation.proprio_state" in processed_obs - assert "observation.image" in processed_obs - assert "observation.state" not in processed_obs + assert OBS_IMAGE in processed_obs + assert OBS_STATE not in processed_obs assert "pixels" not in processed_obs # Check unchanged keys @@ -174,8 +175,8 @@ def test_state_dict(): def test_integration_with_robot_processor(): """Test integration with RobotProcessor pipeline.""" rename_map = { - "agent_pos": "observation.state", - "pixels": "observation.image", + "agent_pos": OBS_STATE, + "pixels": OBS_IMAGE, } rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) @@ -196,8 +197,8 @@ def test_integration_with_robot_processor(): processed_obs = result[TransitionKey.OBSERVATION] # Check renaming worked through pipeline - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs assert "agent_pos" not in processed_obs assert "pixels" not in processed_obs assert processed_obs["other_data"] == "preserve_me" @@ -210,8 +211,8 @@ def test_integration_with_robot_processor(): def test_save_and_load_pretrained(): """Test saving and loading processor with RobotProcessor.""" rename_map = { - "old_state": "observation.state", - "old_image": "observation.image", + "old_state": OBS_STATE, + "old_image": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") @@ -253,10 +254,10 @@ def test_save_and_load_pretrained(): result = loaded_pipeline(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs - assert processed_obs["observation.state"] == [1, 2, 3] - assert processed_obs["observation.image"] == "image_data" + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "image_data" def test_registry_functionality(): @@ -317,8 +318,8 @@ def test_chained_rename_processors(): # Second processor: rename to final format processor2 = RenameObservationsProcessorStep( rename_map={ - "agent_position": "observation.state", - "camera_image": "observation.image", + "agent_position": OBS_STATE, + "camera_image": OBS_IMAGE, } ) @@ -342,8 +343,8 @@ def test_chained_rename_processors(): # After second processor final_obs = results[2][TransitionKey.OBSERVATION] - assert "observation.state" in final_obs - assert "observation.image" in final_obs + assert OBS_STATE in final_obs + assert OBS_IMAGE in final_obs assert final_obs["extra"] == "keep_me" # Original keys should be gone @@ -356,15 +357,15 @@ def test_chained_rename_processors(): def test_nested_observation_rename(): """Test renaming with nested observation structures.""" rename_map = { - "observation.images.left": "observation.camera.left_view", - "observation.images.right": "observation.camera.right_view", + f"{OBS_IMAGES}.left": "observation.camera.left_view", + f"{OBS_IMAGES}.right": "observation.camera.right_view", "observation.proprio": "observation.proprioception", } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.images.left": torch.randn(3, 64, 64), - "observation.images.right": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.left": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.right": torch.randn(3, 64, 64), "observation.proprio": torch.randn(7), "observation.gripper": torch.tensor([0.0]), # Not renamed } @@ -382,8 +383,8 @@ def test_nested_observation_rename(): assert "observation.gripper" in processed_obs # Check old keys removed - assert "observation.images.left" not in processed_obs - assert "observation.images.right" not in processed_obs + assert f"{OBS_IMAGES}.left" not in processed_obs + assert f"{OBS_IMAGES}.right" not in processed_obs assert "observation.proprio" not in processed_obs @@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) processor2 = RenameObservationsProcessorStep( - rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE} ) pipeline = DataProcessorPipeline([processor1, processor2]) @@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory): } out = pipeline.transform_features(initial_features=spec) - assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.state"] - == spec[PipelineFeatureType.OBSERVATION]["pos"] - ) - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.image"] - == spec[PipelineFeatureType.OBSERVATION]["img"] - ) + assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"} + assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"] + assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"] assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert_contract_is_typed(out) def test_rename_stats_basic(): orig = { - "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}, "action": {"mean": np.array([0.0])}, } - mapping = {"observation.state": "observation.robot_state"} + mapping = {OBS_STATE: "observation.robot_state"} renamed = rename_stats(orig, mapping) - assert "observation.robot_state" in renamed and "observation.state" not in renamed + assert "observation.robot_state" in renamed and OBS_STATE not in renamed # Ensure deep copy: mutate original and verify renamed unaffected - orig["observation.state"]["mean"][0] = 42.0 + orig[OBS_STATE]["mean"][0] = 42.0 assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 9e6c8de2..35bbcfd8 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -11,7 +11,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_LANGUAGE +from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE from tests.utils import require_package @@ -503,16 +503,14 @@ def test_features_basic(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) input_features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } output_features = processor.transform_features(input_features) # Check that original features are preserved - assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] + assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION] assert "action" in output_features[PipelineFeatureType.ACTION] # Check that tokenized features are added @@ -797,7 +795,7 @@ def test_device_detection_cpu(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CPU tensors - observation = {"observation.state": torch.randn(10)} # CPU tensor + observation = {OBS_STATE: torch.randn(10)} # CPU tensor action = torch.randn(5) # CPU tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -821,7 +819,7 @@ def test_device_detection_cuda(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CUDA tensors - observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor action = torch.randn(5).cuda() # CUDA tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -847,7 +845,7 @@ def test_device_detection_multi_gpu(): # Test with tensors on cuda:1 device = torch.device("cuda:1") - observation = {"observation.state": torch.randn(10).to(device)} + observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) transition = create_transition( observation=observation, action=action, complementary_data={"task": "multi gpu test"} @@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with float tensor (to test dtype isn't affected) - observation = {"observation.state": torch.randn(10, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)} transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) result = processor(transition) @@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): # Start with CPU tensors transition = create_transition( - observation={"observation.state": torch.randn(10)}, # CPU + observation={OBS_STATE: torch.randn(10)}, # CPU action=torch.randn(5), # CPU complementary_data={"task": "pipeline test"}, ) @@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): result = robot_processor(transition) # All tensors should end up on CUDA (moved by DeviceProcessorStep) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Tokenized tensors should also be on CUDA @@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario(): # Simulate Accelerate scenario: batch already on GPU device = torch.device("cuda:0") observation = { - "observation.state": torch.randn(1, 10).to(device), # Batched, on GPU - "observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU + OBS_STATE: torch.randn(1, 10).to(device), # Batched, on GPU + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU } action = torch.randn(1, 5).to(device) # Batched, on GPU diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index aa9913bb..ec67f188 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -21,6 +21,7 @@ import pytest import torch from torch.multiprocessing import Event, Queue +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -110,12 +111,12 @@ def test_push_transitions_to_transport_queue(): transitions = [] for i in range(3): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(False), truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i)}, ) transitions.append(transition) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 43a6b095..5d95dee0 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]: transitions = [] for i in range(count): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(i == count - 1), # Last transition is done truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, ) transitions.append(transition) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b5254f39..6820d321 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,11 +22,12 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_REPO_ID def state_dims() -> list[str]: - return ["observation.image", "observation.state"] + return [OBS_IMAGE, OBS_STATE] @pytest.fixture @@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor: def create_dummy_transition() -> dict: return { - "observation.image": create_random_image(), + OBS_IMAGE: create_random_image(), "action": torch.randn(4), "reward": torch.tensor(1.0), - "observation.state": torch.randn( + OBS_STATE: torch.randn( 10, ), "done": torch.tensor(False), @@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB def create_dummy_state() -> dict: return { - "observation.image": create_random_image(), - "observation.state": torch.randn( + OBS_IMAGE: create_random_image(), + OBS_STATE: torch.randn( 10, ), } @@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer): def test_zero_capacity_buffer_raises_error(): with pytest.raises(ValueError, match="Capacity must be greater than 0."): - ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"]) def test_add_transition(replay_buffer, dummy_state, dummy_action): @@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action): def test_add_over_capacity(): - replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"]) dummy_state_1 = create_dummy_state() dummy_action_1 = create_dummy_action() @@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path): assert ds.num_frames == 4 for j, value in enumerate(ds): - print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j])) for i in range(len(ds)): for feature, value in ds[i].items(): @@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path): assert torch.equal(value, buffer.rewards[i]) elif feature == "next.done": assert torch.equal(value, buffer.dones[i]) - elif feature == "observation.image": + elif feature == OBS_IMAGE: # Tensor -> numpy is not precise, so we have some diff there # TODO: Check and fix it - torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) - elif feature == "observation.state": - assert torch.equal(value, buffer.states["observation.state"][i]) + torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003) + elif feature == OBS_STATE: + assert torch.equal(value, buffer.states[OBS_STATE][i]) def test_from_lerobot_dataset(tmp_path): @@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path): ) assert torch.equal( - replay_buffer.states["observation.state"][: len(replay_buffer)], - reconverted_buffer.states["observation.state"][: len(replay_buffer)], + replay_buffer.states[OBS_STATE][: len(replay_buffer)], + reconverted_buffer.states[OBS_STATE][: len(replay_buffer)], ), "State should be the same after converting to dataset and return back" for i in range(4): torch.testing.assert_close( - replay_buffer.states["observation.image"][i], - reconverted_buffer.states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][i], + reconverted_buffer.states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) @@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path): next_index = (i + 1) % 4 torch.testing.assert_close( - replay_buffer.states["observation.image"][next_index], - reconverted_buffer.next_states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][next_index], + reconverted_buffer.next_states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) for i in range(2, 4): assert torch.equal( - replay_buffer.states["observation.state"][i], - reconverted_buffer.next_states["observation.state"][i], + replay_buffer.states[OBS_STATE][i], + reconverted_buffer.next_states[OBS_STATE][i], ) @@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) sampled_transitions = replay_buffer.sample(1) - assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( - "Image augmentations should be applied" - ) - assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied" + assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), ( "Image augmentations should be applied" ) @@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct # Let's check that it doesn't fail and shapes are correct sampled_transitions = replay_buffer.sample(1) - assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) - assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84) def test_random_crop_vectorized_basic(): @@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: buffer = ReplayBuffer( capacity=capacity, device="cpu", - state_keys=["observation.image", "observation.state"], + state_keys=[OBS_IMAGE, OBS_STATE], storage_device="cpu", ) @@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: img = torch.ones(3, 128, 128) * i state_vec = torch.arange(11).float() + i state = { - "observation.image": img, - "observation.state": state_vec, + OBS_IMAGE: img, + OBS_STATE: state_vec, } buffer.add( state=state, @@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic(): iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) @@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations(): for _ in range(5): batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 29b7bf70..65a97c6a 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -6,6 +6,7 @@ import numpy as np import pytest from lerobot.processor import TransitionKey +from lerobot.utils.constants import OBS_STATE @pytest.fixture @@ -72,7 +73,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # Build EnvTransition dict obs = { - "observation.state.temperature": np.float32(25.0), + f"{OBS_STATE}.temperature": np.float32(25.0), # CHW image should be converted to HWC for rr.Image "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), } @@ -97,7 +98,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # - action.throttle -> Scalar # - action.vector_0, action.vector_1 -> Scalars expected_keys = { - "observation.state.temperature", + f"{OBS_STATE}.temperature", "observation.camera", "action.throttle", "action.vector_0", @@ -106,7 +107,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): assert set(_keys(calls)) == expected_keys # Check scalar types and values - temp_obj = _obj_for(calls, "observation.state.temperature") + temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature") assert type(temp_obj).__name__ == "DummyScalar" assert temp_obj.value == pytest.approx(25.0)