#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Public API of the ``lerobot.processor`` package.

Re-exports the processor steps and the pipeline primitives so callers can
write ``from lerobot.processor import RobotProcessor`` etc.
"""

from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
    ActionProcessor,
    DoneProcessor,
    EnvTransition,
    IdentityProcessor,
    InfoProcessor,
    ObservationProcessor,
    ProcessorStep,
    ProcessorStepRegistry,
    RewardProcessor,
    RobotProcessor,
    TransitionKey,
    TruncatedProcessor,
)
from .rename_processor import RenameProcessor

# Kept strictly alphabetical so additions are easy to diff-review.
# (Fix: "UnnormalizerProcessor" was previously misplaced right after
# "NormalizerProcessor", breaking the otherwise-sorted order.)
__all__ = [
    "ActionProcessor",
    "DeviceProcessor",
    "DoneProcessor",
    "EnvTransition",
    "IdentityProcessor",
    "InfoProcessor",
    "NormalizerProcessor",
    "ObservationProcessor",
    "ProcessorStep",
    "ProcessorStepRegistry",
    "RenameProcessor",
    "RewardProcessor",
    "RobotProcessor",
    "TransitionKey",
    "TruncatedProcessor",
    "UnnormalizerProcessor",
    "VanillaObservationProcessor",
]
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any

import torch

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device


@dataclass
class DeviceProcessor:
    """Processor step that moves every tensor in a transition to one device.

    All ``torch.Tensor`` values found in the observation dict and in the
    action/reward/done/truncated slots are moved to ``device``; non-tensor
    values pass through unchanged. The input transition is shallow-copied,
    never mutated.
    """

    # Accepts a device string ("cpu", "cuda", "cuda:0", ...) or a
    # torch.device; normalised to a torch.device in __post_init__.
    device: torch.device | str = "cpu"

    def __post_init__(self):
        self.device = get_safe_torch_device(self.device)
        # non_blocking copies are only meaningful for transfers to CUDA.
        self.non_blocking = "cuda" in str(self.device)

    def _to_device(self, value: Any) -> Any:
        """Move *value* to the target device if it is a tensor, else return it unchanged."""
        if isinstance(value, torch.Tensor):
            return value.to(self.device, non_blocking=self.non_blocking)
        return value

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Shallow copy so the caller's transition dict is left untouched.
        new_transition = transition.copy()

        # Observations are a (possibly mixed) dict: move tensors, keep the rest.
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is not None:
            new_transition[TransitionKey.OBSERVATION] = {
                k: self._to_device(v) for k, v in observation.items()
            }

        # The remaining slots hold at most a single tensor each; one loop
        # replaces the four duplicated if-blocks of the original version.
        for key in (
            TransitionKey.ACTION,
            TransitionKey.REWARD,
            TransitionKey.DONE,
            TransitionKey.TRUNCATED,
        ):
            value = transition.get(key)
            if isinstance(value, torch.Tensor):
                new_transition[key] = value.to(self.device, non_blocking=self.non_blocking)

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return a JSON-serializable configuration for this step.

        Fix: after ``__post_init__`` the attribute is a ``torch.device``,
        which is not JSON-serializable; persist its string form so the
        config can round-trip through ``json.dumps``.
        """
        return {"device": str(self.device)}

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # Moving tensors between devices does not change the feature schema.
        return features
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey


def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
    """Convert numpy arrays and other types to torch tensors.

    Every leaf statistic (mean/std/min/max, ...) is converted to a float32
    tensor so downstream arithmetic never mixes dtypes. Raises TypeError for
    values that are neither ndarray, Tensor, nor plain numbers/sequences.
    """
    tensor_stats: dict[str, dict[str, Tensor]] = {}
    for key, sub in stats.items():
        tensor_stats[key] = {}
        for stat_name, value in sub.items():
            if isinstance(value, np.ndarray):
                tensor_val = torch.from_numpy(value.astype(np.float32))
            elif isinstance(value, torch.Tensor):
                tensor_val = value.to(dtype=torch.float32)
            elif isinstance(value, (int, float, list, tuple)):
                tensor_val = torch.tensor(value, dtype=torch.float32)
            else:
                raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
            tensor_stats[key][stat_name] = tensor_val
    return tensor_stats


@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
    """Normalizes observations and actions in a single processor step.

    This processor handles normalization of both observation and action tensors
    using either mean/std normalization or min/max scaling to a [-1, 1] range.

    For each tensor key in the stats dictionary, the processor will:
    - Use mean/std normalization if those statistics are provided: (x - mean) / std
    - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1

    The processor can be configured to normalize only specific keys by setting
    the normalize_keys parameter.
    """

    # Features and normalisation map are mandatory to match the design of normalize.py
    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]

    # Pre-computed statistics coming from dataset.meta.stats for instance.
    stats: dict[str, dict[str, Any]] | None = None

    # Explicit subset of keys to normalise. If ``None`` every key (except
    # "action") found in ``stats`` will be normalised. Using a ``set`` makes
    # membership checks O(1).
    normalize_keys: set[str] | None = None

    # Added to denominators to avoid division by zero for constant features.
    eps: float = 1e-8

    # Tensor cache built from ``stats`` in __post_init__; excluded from init/repr.
    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
        *,
        normalize_keys: set[str] | None = None,
        eps: float = 1e-8,
    ) -> NormalizerProcessor:
        """Factory helper that pulls statistics from a :class:`LeRobotDataset`.

        The features and norm_map parameters are mandatory to match the design
        pattern used in normalize.py.
        """

        return cls(
            features=features,
            norm_map=norm_map,
            stats=dataset.meta.stats,
            normalize_keys=normalize_keys,
            eps=eps,
        )

    def __post_init__(self):
        # Handle deserialization from JSON config: when loaded via get_config()
        # round-trip, features/norm_map arrive as plain dicts/strings and must
        # be rebuilt into PolicyFeature / enum objects.
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features

        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map

        # Convert statistics once so we avoid repeated numpy→Tensor conversions
        # during runtime.
        self.stats = self.stats or {}
        self._tensor_stats = _convert_stats_to_tensors(self.stats)

        # Ensure *normalize_keys* is a set for fast look-ups and compare by
        # value later when returning the configuration.
        if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
            self.normalize_keys = set(self.normalize_keys)

    def _normalize_obs(self, observation: dict[str, Any] | None) -> dict[str, Any] | None:
        """Return a copy of *observation* with the selected keys normalized.

        Keys without stats are passed through untouched. NOTE(review): the
        norm_map mode is not consulted here — which transform applies is
        decided by which stats (mean/std vs min/max) are present.
        """
        if observation is None:
            return None

        # Decide which keys should be normalised for this call.
        if self.normalize_keys is not None:
            keys_to_norm = self.normalize_keys
        else:
            # Use feature map to skip action keys.
            keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}

        processed = dict(observation)
        for key in keys_to_norm:
            if key not in processed or key not in self._tensor_stats:
                continue

            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            # Move cached stats to the tensor's device on every call; cheap for
            # CPU-only use, but repeated H2D copies if inputs live on GPU.
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}

            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = (tensor - mean) / (std + self.eps)
            elif "min" in stats and "max" in stats:
                # Maps [min, max] to [-1, 1]; eps guards constant features.
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        return processed

    def _normalize_action(self, action: Any) -> Any:
        """Normalize the action with the stats stored under the "action" key.

        Unlike observations, missing both stat pairs is an error here.
        """
        if action is None or "action" not in self._tensor_stats:
            return action

        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return (tensor - mean) / (std + self.eps)
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._normalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with normalized values.
        # NOTE(review): if the input lacked these keys they are written back as
        # None, so output transitions always carry both keys — presumably
        # intentional; confirm against pipeline consumers.
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """JSON-serializable config; stats themselves live in state_dict()."""
        config = {
            "eps": self.eps,
            "features": {
                key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }
        if self.normalize_keys is not None:
            # Serialise as a list for YAML / JSON friendliness
            config["normalize_keys"] = sorted(self.normalize_keys)
        return config

    def state_dict(self) -> dict[str, Tensor]:
        # Flatten nested stats to "key.stat_name" so safetensors can store them.
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        # Inverse of state_dict(); rsplit on the last "." so feature keys that
        # themselves contain dots (e.g. "observation.state") survive.
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def reset(self):
        # Stateless between episodes; nothing to clear.
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # Normalization changes values, not the feature schema.
        return features


@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor:
    """Inverse normalisation for observations and actions.

    Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
    transform (x * std + mean, or the [-1, 1] -> [min, max] affine map).
    """

    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    stats: dict[str, dict[str, Any]] | None = None

    # Tensor cache built from ``stats`` in __post_init__; excluded from init/repr.
    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
    ) -> UnnormalizerProcessor:
        """Factory helper that pulls statistics from a :class:`LeRobotDataset`."""
        return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)

    def __post_init__(self):
        # Handle deserialization from JSON config (same round-trip handling as
        # NormalizerProcessor.__post_init__).
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features

        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map

        self.stats = self.stats or {}
        self._tensor_stats = _convert_stats_to_tensors(self.stats)

    def _unnormalize_obs(self, observation: dict[str, Any] | None) -> dict[str, Any] | None:
        """Invert normalization for all non-action feature keys with stats."""
        if observation is None:
            return None
        keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
        processed = dict(observation)
        for key in keys:
            if key not in processed or key not in self._tensor_stats:
                continue
            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = tensor * std + mean
            elif "min" in stats and "max" in stats:
                # Inverse of the [-1, 1] scaling used by NormalizerProcessor.
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
        return processed

    def _unnormalize_action(self, action: Any) -> Any:
        """Invert action normalization; errors if neither stat pair exists."""
        if action is None or "action" not in self._tensor_stats:
            return action
        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return tensor * std + mean
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return (tensor + 1) / 2 * (max_val - min_val) + min_val
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._unnormalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with unnormalized values
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """JSON-serializable config; stats themselves live in state_dict()."""
        return {
            "features": {
                key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }

    def state_dict(self) -> dict[str, Tensor]:
        # Flatten nested stats to "key.stat_name" for safetensors storage.
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        # rsplit keeps dots inside feature keys intact.
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def reset(self):
        # Stateless between episodes; nothing to clear.
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # Unnormalization changes values, not the feature schema.
        return features
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import einops
import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
    """
    Processes environment observations into the LeRobot format by handling both images and states.

    Image processing:
    - Converts channel-last (H, W, C) images to channel-first (C, H, W)
    - Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
    - Adds a batch dimension if missing
    - Supports single images and image dictionaries

    State processing:
    - Maps 'environment_state' to observation.environment_state
    - Maps 'agent_pos' to observation.state
    - Converts numpy arrays to tensors
    - Adds a batch dimension if missing
    """

    def _process_single_image(self, img: np.ndarray) -> Tensor:
        """Process a single image array.

        Expects a uint8, channel-last array of shape (H, W, C) or (B, H, W, C);
        returns a float32 (B, C, H, W) tensor scaled to [0, 1]. Raises
        ValueError for non-channel-last or non-uint8 inputs.
        """
        # Convert to tensor (zero-copy view of the numpy buffer)
        img_tensor = torch.from_numpy(img)

        # Add batch dimension if needed
        if img_tensor.ndim == 3:
            img_tensor = img_tensor.unsqueeze(0)

        # Validate image format.
        # Heuristic: channels must be the smallest trailing dim — assumes
        # images are wider/taller than their channel count (TODO confirm for
        # tiny images where c >= h or c >= w).
        _, h, w, c = img_tensor.shape
        if not (c < h and c < w):
            raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")

        if img_tensor.dtype != torch.uint8:
            raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

        # Convert to channel-first format
        img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()

        # Convert to float32 and normalize to [0, 1]
        img_tensor = img_tensor.type(torch.float32) / 255.0

        return img_tensor

    def _process_observation(self, observation: dict) -> dict:
        """
        Processes both image and state observations.

        Pops the gym-style keys ("pixels", "environment_state", "agent_pos")
        and re-inserts their processed values under the LeRobot constants.
        Other keys are carried over unchanged.
        """

        processed_obs = observation.copy()

        if "pixels" in processed_obs:
            pixels = processed_obs.pop("pixels")

            # A dict of cameras maps to "<OBS_IMAGES>.<cam>", a single array to OBS_IMAGE.
            if isinstance(pixels, dict):
                imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
            else:
                imgs = {OBS_IMAGE: pixels}

            for imgkey, img in imgs.items():
                processed_obs[imgkey] = self._process_single_image(img)

        # NOTE(review): both state branches assume numpy input — torch.from_numpy
        # raises on an already-torch tensor; confirm upstream always sends numpy.
        if "environment_state" in processed_obs:
            env_state_np = processed_obs.pop("environment_state")
            env_state = torch.from_numpy(env_state_np).float()
            if env_state.dim() == 1:
                env_state = env_state.unsqueeze(0)
            processed_obs[OBS_ENV_STATE] = env_state

        if "agent_pos" in processed_obs:
            agent_pos_np = processed_obs.pop("agent_pos")
            agent_pos = torch.from_numpy(agent_pos_np).float()
            if agent_pos.dim() == 1:
                agent_pos = agent_pos.unsqueeze(0)
            processed_obs[OBS_STATE] = agent_pos

        return processed_obs

    def observation(self, observation: dict) -> dict:
        # Entry point called by the ObservationProcessor base.
        return self._process_observation(observation)

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms feature keys to a standardized contract.

        This method handles several renaming patterns:
        - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
        - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
        - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
        - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
        - environment_state -> OBS_ENV_STATE,
        - agent_pos -> OBS_STATE,
        - observation.environment_state -> OBS_ENV_STATE,
        - observation.agent_pos -> OBS_STATE

        Mutates and returns the *features* dict in place.
        """
        exact_pairs = {
            "pixels": OBS_IMAGE,
            "environment_state": OBS_ENV_STATE,
            "agent_pos": OBS_STATE,
        }

        prefix_pairs = {
            "pixels.": f"{OBS_IMAGES}.",
        }

        # Iterate over a snapshot of keys since we pop/insert while looping.
        for key in list(features.keys()):
            matched_prefix = False
            for old_prefix, new_prefix in prefix_pairs.items():
                # Check the "observation."-prefixed form first so it wins over
                # the bare prefix (which would otherwise never match it).
                prefixed_old = f"observation.{old_prefix}"
                if key.startswith(prefixed_old):
                    suffix = key[len(prefixed_old) :]
                    features[f"{new_prefix}{suffix}"] = features.pop(key)
                    matched_prefix = True
                    break

                if key.startswith(old_prefix):
                    suffix = key[len(old_prefix) :]
                    features[f"{new_prefix}{suffix}"] = features.pop(key)
                    matched_prefix = True
                    break

            if matched_prefix:
                continue

            for old, new in exact_pairs.items():
                if key == old or key == f"observation.{old}":
                    if key in features:
                        features[new] = features.pop(key)
                    break

        return features
+from __future__ import annotations + +import importlib +import json +import os +from collections.abc import Callable, Iterable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Protocol, TypedDict + +import torch +from huggingface_hub import ModelHubMixin, hf_hub_download +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_file, save_file + +from lerobot.configs.types import PolicyFeature + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) + + +class ProcessorStepRegistry: + """Registry for processor steps that enables saving/loading by name instead of module path.""" + + _registry: dict[str, type] = {} + + @classmethod + def register(cls, name: str = None): + """Decorator to register a processor step class. + + Args: + name: Optional registration name. If not provided, uses class name. + + Example: + @ProcessorStepRegistry.register("adaptive_normalizer") + class AdaptiveObservationNormalizer: + ... 
+ """ + + def decorator(step_class: type) -> type: + registration_name = name if name is not None else step_class.__name__ + + if registration_name in cls._registry: + raise ValueError( + f"Processor step '{registration_name}' is already registered. " + f"Use a different name or unregister the existing one first." + ) + + cls._registry[registration_name] = step_class + # Store the registration name on the class for later reference + step_class._registry_name = registration_name + return step_class + + return decorator + + @classmethod + def get(cls, name: str) -> type: + """Get a registered processor step class by name. + + Args: + name: The registration name of the step. + + Returns: + The registered step class. + + Raises: + KeyError: If the step is not registered. + """ + if name not in cls._registry: + available = list(cls._registry.keys()) + raise KeyError( + f"Processor step '{name}' not found in registry. " + f"Available steps: {available}. " + f"Make sure the step is registered using @ProcessorStepRegistry.register()" + ) + return cls._registry[name] + + @classmethod + def unregister(cls, name: str) -> None: + """Remove a step from the registry.""" + cls._registry.pop(name, None) + + @classmethod + def list(cls) -> list[str]: + """List all registered step names.""" + return list(cls._registry.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registrations.""" + cls._registry.clear() + + +class ProcessorStep(Protocol): + """Structural typing interface for a single processor step. + + A step is any callable accepting a full `EnvTransition` dict and + returning a (possibly modified) dict of the same structure. Implementers + are encouraged—but not required—to expose the optional helper methods + listed below. When present, these hooks let `RobotProcessor` + automatically serialise the step's configuration and learnable state using + a safe-to-share JSON + SafeTensors format. 
class ProcessorStep(Protocol):
    """Structural typing interface for a single processor step.

    A step is any callable accepting a full `EnvTransition` dict and
    returning a (possibly modified) dict of the same structure. Implementers
    are encouraged—but not required—to expose the optional helper methods
    listed below. When present, these hooks let `RobotProcessor`
    automatically serialise the step's configuration and learnable state using
    a safe-to-share JSON + SafeTensors format.

    **Required**:
    - ``__call__(transition: EnvTransition) -> EnvTransition``
    - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]``

    Optional helper protocol:
    * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
      configuration and state. YOU decide what to save here. This is where all
      non-tensor state goes (e.g., name, counter, threshold, window_size).
      The config dict will be passed to your class constructor when loading.
    * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
      This is exclusively for torch.Tensor objects (e.g., learned weights,
      running statistics as tensors). Never put simple Python types here.
    * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
      containing torch tensors only.
    * ``reset()`` – Clear internal buffers at episode boundaries.

    Example separation:
    - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
    - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
    """

    def __call__(self, transition: EnvTransition) -> EnvTransition: ...

    def get_config(self) -> dict[str, Any]: ...

    def state_dict(self) -> dict[str, torch.Tensor]: ...

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...

    def reset(self) -> None: ...

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ...


def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition:  # noqa: D401
    """Convert a *batch* dict coming from LeRobot replay/dataset code into an
    ``EnvTransition`` dictionary.

    The function maps well known keys to the EnvTransition structure. Missing
    keys are filled with sane defaults (``None`` or ``0.0``/``False``).

    Keys recognised (case-sensitive):

    * "observation.*" (keys starting with "observation." are grouped into observation dict)
    * "action"
    * "next.reward"
    * "next.done"
    * "next.truncated"
    * "info"
    * "*_is_pad" and "task" (collected into complementary_data)

    Additional keys are ignored so that existing dataloaders can carry extra
    metadata without breaking the processor.
    """

    # Extract observation keys; None (not an empty dict) when no key matches.
    observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
    observation = observation_keys if observation_keys else None

    # Extract padding and task keys for complementary data.
    # (Fix: merging two possibly-empty dicts already yields {}, so the former
    # `... if pad_keys or task_key else {}` conditional was redundant.)
    pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
    task_key = {"task": batch["task"]} if "task" in batch else {}
    complementary_data = {**pad_keys, **task_key}

    transition: EnvTransition = {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: batch.get("action"),
        TransitionKey.REWARD: batch.get("next.reward", 0.0),
        TransitionKey.DONE: batch.get("next.done", False),
        TransitionKey.TRUNCATED: batch.get("next.truncated", False),
        TransitionKey.INFO: batch.get("info", {}),
        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
    }
    return transition


def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:  # noqa: D401
    """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with
    the canonical field names used throughout *LeRobot*.
    """

    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
        "next.done": transition.get(TransitionKey.DONE, False),
        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
        "info": transition.get(TransitionKey.INFO, {}),
    }

    # Add padding and task data from complementary_data
    complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    if complementary_data:
        pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
        batch.update(pad_data)

        if "task" in complementary_data:
            batch["task"] = complementary_data["task"]

    # Handle observation - flatten dict to observation.* keys if it's a dict
    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict):
        batch.update(observation)

    return batch
+ """ + + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add padding and task data from complementary_data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data: + pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} + batch.update(pad_data) + + if "task" in complementary_data: + batch["task"] = complementary_data["task"] + + # Handle observation - flatten dict to observation.* keys if it's a dict + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch + + +@dataclass +class RobotProcessor(ModelHubMixin): + """ + Composable, debuggable post-processing processor for robot transitions. + + The class orchestrates an ordered collection of small, functional transforms—steps—executed + left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts + and batch dictionaries, automatically converting between formats as needed. + + Args: + steps: Ordered list of processing steps executed on every call. Defaults to empty list. + name: Human-readable identifier that is persisted inside the JSON config. + Defaults to "RobotProcessor". + to_transition: Function to convert batch dict to EnvTransition dict. + Defaults to _default_batch_to_transition. + to_output: Function to convert EnvTransition dict to the desired output format. + Usually it is a batch dict or EnvTransition dict. + Defaults to _default_transition_to_batch. + before_step_hooks: List of hooks called before each step. Each hook receives the step + index and transition, and can optionally return a modified transition. + after_step_hooks: List of hooks called after each step. 
Each hook receives the step + index and transition, and can optionally return a modified transition. + + Hook Semantics: + - Hooks are executed sequentially in the order they were registered. There is no way to + reorder hooks after registration without creating a new pipeline. + - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called + with the step index and current transition for logging, debugging, or monitoring purposes. + - All hooks for a given type (before/after) are executed for every step, or none at all if + an error occurs. There is no partial execution of hooks. + - Hooks should generally be stateless to maintain predictable behavior. If you need stateful + processing, consider implementing a proper ProcessorStep instead. + - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. + - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format + passed to __call__. This ensures consistent hook behavior whether processing batch dicts + or EnvTransition objects. + """ + + steps: Sequence[ProcessorStep] = field(default_factory=list) + name: str = "RobotProcessor" + + to_transition: Callable[[dict[str, Any]], EnvTransition] = field( + default_factory=lambda: _default_batch_to_transition, repr=False + ) + to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field( + default_factory=lambda: _default_transition_to_batch, repr=False + ) + + # Processor-level hooks for observation/monitoring + # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes + before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + + def __call__(self, data: EnvTransition | dict[str, Any]): + """Process data through all steps. 
+
+        The method accepts either the classic EnvTransition dict or a batch dictionary
+        (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
+        it is first converted to the internal dict format using to_transition; after all
+        steps are executed the dict is transformed back into a batch dict with to_output and the
+        result is returned – thereby preserving the caller's original data type.
+
+        Args:
+            data: Either an EnvTransition dict or a batch dictionary to process.
+
+        Returns:
+            The processed data in the same format as the input (EnvTransition or batch dict).
+
+        Raises:
+            ValueError: If the transition is not a valid EnvTransition format.
+        """
+        # Check if we need to convert back to batch format at the end
+        _, called_with_batch = self._prepare_transition(data)
+
+        # Use step_through to get the iterator
+        step_iterator = self.step_through(data)
+
+        # Get initial state (before any steps)
+        current_transition = next(step_iterator)
+
+        # Process each step with hooks
+        for idx, next_transition in enumerate(step_iterator):
+            # Apply before hooks with current state (before step execution)
+            for hook in self.before_step_hooks:
+                hook(idx, current_transition)
+
+            # Move to next state (after step execution)
+            current_transition = next_transition
+
+            # Apply after hooks with updated state
+            for hook in self.after_step_hooks:
+                hook(idx, current_transition)
+
+        # Convert back to original format if needed
+        return self.to_output(current_transition) if called_with_batch else current_transition
+
+    def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
+        """Prepare and validate transition data for processing.
+
+        Args:
+            data: Either an EnvTransition dict or a batch dictionary to process.
+
+        Returns:
+            A tuple of (prepared_transition, called_with_batch_flag)
+
+        Raises:
+            ValueError: If the transition is not a valid EnvTransition format.
+ """ + # Check if data is already an EnvTransition or needs conversion + if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): + # It's a batch dict, convert it + called_with_batch = True + transition = self.to_transition(data) + else: + # It's already an EnvTransition + called_with_batch = False + transition = data + + # Basic validation + if not isinstance(transition, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") + + return transition, called_with_batch + + def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: + """Yield the intermediate results after each processor step. + + This is a low-level method that does NOT apply hooks. It simply executes each step + and yields the intermediate results. This allows users to debug the pipeline or + apply custom logic between steps if needed. + + Note: This method always yields EnvTransition objects regardless of input format. + If you need the results in the original input format, you'll need to convert them + using `to_output()`. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Yields: + The intermediate EnvTransition results after each step. 
+ """ + transition, _ = self._prepare_transition(data) + + # Yield initial state + yield transition + + # Process each step WITHOUT hooks (low-level method) + for processor_step in self.steps: + transition = processor_step(transition) + yield transition + + def _save_pretrained(self, save_directory: Path, **kwargs): + """Internal save method for ModelHubMixin compatibility.""" + # Extract config_filename from kwargs if provided + config_filename = kwargs.pop("config_filename", None) + self.save_pretrained(save_directory, config_filename=config_filename) + + def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): + """Serialize the processor definition and parameters to *save_directory*. + + Args: + save_directory: Directory where the processor will be saved. + config_filename: Optional custom config filename. If not provided, defaults to + "{self.name}.json" where self.name is sanitized for filesystem compatibility. + """ + os.makedirs(str(save_directory), exist_ok=True) + + # Sanitize processor name for use in filenames + import re + + # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + + # Use sanitized name for config if not provided + if config_filename is None: + config_filename = f"{sanitized_name}.json" + + config: dict[str, Any] = { + "name": self.name, + "steps": [], + } + + for step_index, processor_step in enumerate(self.steps): + # Check if step was registered + registry_name = getattr(processor_step.__class__, "_registry_name", None) + + step_entry: dict[str, Any] = {} + if registry_name: + # Use registry name for registered steps + step_entry["registry_name"] = registry_name + else: + # Fall back to full module path for unregistered steps + step_entry["class"] = ( + f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" + ) + + if hasattr(processor_step, "get_config"): + 
step_entry["config"] = processor_step.get_config() + + if hasattr(processor_step, "state_dict"): + state = processor_step.state_dict() + if state: + # Clone tensors to avoid shared memory issues + # This ensures each tensor has its own memory allocation + # The reason is to avoid the following error: + # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk + # and potential differences when loading them again + # ------------------------------------------------------------------------------ + # Since the state_dict of processor will be light, we can just clone the tensors + # and save them to the disk. + cloned_state = {} + for key, tensor in state.items(): + cloned_state[key] = tensor.clone() + + # Include pipeline name and step index to ensure unique filenames + # This prevents conflicts when multiple processors are saved in the same directory + if registry_name: + state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" + else: + state_filename = f"{sanitized_name}_step_{step_index}.safetensors" + + save_file(cloned_state, os.path.join(str(save_directory), state_filename)) + step_entry["state_file"] = state_filename + + config["steps"].append(step_entry) + + with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: + json.dump(config, file_pointer, indent=2) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + config_filename: str | None = None, + overrides: dict[str, Any] | None = None, + **kwargs, + ) -> RobotProcessor: + """Load a serialized processor from source (local path or Hugging Face Hub identifier). 
+ + Args: + pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier + (e.g., "username/processor-name"). + config_filename: Optional specific config filename to load. If not provided, will: + - For local paths: look for any .json file in the directory (error if multiple found) + - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") + overrides: Optional dictionary mapping step names to configuration overrides. + Keys must match exact step class names (for unregistered steps) or registry names + (for registered steps). Values are dictionaries containing parameter overrides + that will be merged with the saved configuration. This is useful for providing + non-serializable objects like environment instances. + + Returns: + A RobotProcessor instance loaded from the saved configuration. + + Raises: + ImportError: If a processor step class cannot be loaded or imported. + ValueError: If a step cannot be instantiated with the provided configuration. + KeyError: If an override key doesn't match any step in the saved configuration. 
+ + Examples: + Basic loading: + ```python + processor = RobotProcessor.from_pretrained("path/to/processor") + ``` + + Loading specific config file: + ```python + processor = RobotProcessor.from_pretrained( + "username/multi-processor-repo", config_filename="preprocessor.json" + ) + ``` + + Loading with overrides for non-serializable objects: + ```python + import gym + + env = gym.make("CartPole-v1") + processor = RobotProcessor.from_pretrained( + "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} + ) + ``` + + Multiple overrides: + ```python + processor = RobotProcessor.from_pretrained( + "path/to/processor", + overrides={ + "CustomStep": {"param1": "new_value"}, + "device_processor": {"device": "cuda:1"}, # For registered steps + }, + ) + ``` + """ + # Use the local variable name 'source' for clarity + source = str(pretrained_model_name_or_path) + + if Path(source).is_dir(): + # Local path - use it directly + base_path = Path(source) + + if config_filename is None: + # Look for any .json file in the directory + json_files = list(base_path.glob("*.json")) + if len(json_files) == 0: + raise FileNotFoundError(f"No .json configuration files found in {source}") + elif len(json_files) > 1: + raise ValueError( + f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " + f"Please specify which one to load using the config_filename parameter." 
+ ) + config_filename = json_files[0].name + + with open(base_path / config_filename) as file_pointer: + loaded_config: dict[str, Any] = json.load(file_pointer) + else: + # Hugging Face Hub - download all required files + if config_filename is None: + # Try common config names + common_names = [ + "processor.json", + "preprocessor.json", + "postprocessor.json", + "robotprocessor.json", + ] + config_path = None + for name in common_names: + try: + config_path = hf_hub_download( + source, + name, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + config_filename = name + break + except (FileNotFoundError, OSError, HfHubHTTPError): + # FileNotFoundError: local file issues + # OSError: network/system errors + # HfHubHTTPError: file not found on Hub (404) or other HTTP errors + continue + + if config_path is None: + raise FileNotFoundError( + f"No processor configuration file found in {source}. " + f"Tried: {common_names}. Please specify the config_filename parameter." 
+ ) + else: + # Download specific config file + config_path = hf_hub_download( + source, + config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + with open(config_path) as file_pointer: + loaded_config = json.load(file_pointer) + + # Store downloaded files in the same directory as the config + base_path = Path(config_path).parent + + # Handle None overrides + if overrides is None: + overrides = {} + + # Validate that all override keys will be matched + override_keys = set(overrides.keys()) + + steps: list[ProcessorStep] = [] + for step_entry in loaded_config["steps"]: + # Check if step uses registry name or module path + if "registry_name" in step_entry: + # Load from registry + try: + step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + step_key = step_entry["registry_name"] + except KeyError as e: + raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e + else: + # Fall back to module path loading for backward compatibility + full_class_path = step_entry["class"] + module_path, class_name = full_class_path.rsplit(".", 1) + + # Import the module containing the step class + try: + module = importlib.import_module(module_path) + step_class = getattr(module, class_name) + step_key = class_name + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to load processor step '{full_class_path}'. " + f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " + f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. 
" + f"Error: {str(e)}" + ) from e + + # Instantiate the step with its config + try: + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + step_instance: ProcessorStep = step_class(**merged_cfg) + + # Track which override keys were used + if step_key in override_keys: + override_keys.discard(step_key) + + except Exception as e: + step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) + raise ValueError( + f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " + f"Error: {str(e)}" + ) from e + + # Load state if available + if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): + if Path(source).is_dir(): + # Local path - read directly + state_path = str(base_path / step_entry["state_file"]) + else: + # Hugging Face Hub - download the state file + state_path = hf_hub_download( + source, + step_entry["state_file"], + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + step_instance.load_state_dict(load_file(state_path)) + + steps.append(step_instance) + + # Check for unused override keys + if override_keys: + available_keys = [] + for step_entry in loaded_config["steps"]: + if "registry_name" in step_entry: + available_keys.append(step_entry["registry_name"]) + else: + full_class_path = step_entry["class"] + class_name = full_class_path.rsplit(".", 1)[1] + available_keys.append(class_name) + + raise KeyError( + f"Override keys {list(override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." 
+ ) + + return cls(steps, loaded_config.get("name", "RobotProcessor")) + + def __len__(self) -> int: + """Return the number of steps in the processor.""" + return len(self.steps) + + def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor: + """Indexing helper exposing underlying steps. + * ``int`` – returns the idx-th ProcessorStep. + * ``slice`` – returns a new RobotProcessor with the sliced steps. + """ + if isinstance(idx, slice): + return RobotProcessor(self.steps[idx], self.name) + return self.steps[idx] + + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed before every processor step.""" + self.before_step_hooks.append(fn) + + def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered before_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.before_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." + ) from None + + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed after every processor step.""" + self.after_step_hooks.append(fn) + + def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered after_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.after_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." 
+            ) from None
+
+    def reset(self):
+        """Clear state in every step that implements ``reset()``."""
+        for step in self.steps:
+            if hasattr(step, "reset"):
+                step.reset()  # type: ignore[attr-defined]
+
+    def __repr__(self) -> str:
+        """Return a readable string representation of the processor."""
+        step_names = [step.__class__.__name__ for step in self.steps]
+
+        if not step_names:
+            steps_repr = "steps=0: []"
+        elif len(step_names) <= 3:
+            steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]"
+        else:
+            # Show first 2 and last 1 with ellipsis for long lists
+            displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}"
+            steps_repr = f"steps={len(step_names)}: [{displayed}]"
+
+        parts = [f"name='{self.name}'", steps_repr]
+
+        return f"RobotProcessor({', '.join(parts)})"
+
+    def __post_init__(self):
+        for i, step in enumerate(self.steps):
+            if not callable(step):
+                raise TypeError(
+                    f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition"
+                )
+
+            fc = getattr(step, "feature_contract", None)
+            if not callable(fc):
+                raise TypeError(
+                    f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]"
+                )
+
+    def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
+        """
+        Apply ALL steps in order. Each step must implement
+        feature_contract(features) and return a dict (full or incremental schema).
+        """
+        features: dict[str, PolicyFeature] = deepcopy(initial_features)
+
+        for _, step in enumerate(self.steps):
+            out = step.feature_contract(features)
+            if not isinstance(out, dict):
+                raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]")
+            features = out
+        return features
+
+
+class ObservationProcessor:
+    """Base class for processors that modify only the observation component of a transition.
+
+    Subclasses should override the `observation` method to implement custom observation processing.
+ This class handles the boilerplate of extracting and reinserting the processed observation + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class MyObservationScaler(ObservationProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def observation(self, observation): + return observation * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific observation processing logic. + """ + + def observation(self, observation): + """Process the observation component. + + Args: + observation: The observation to process + + Returns: + The processed observation + """ + return observation + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + processed_observation = self.observation(observation) + # Create a new transition dict with the processed observation + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ActionProcessor: + """Base class for processors that modify only the action component of a transition. + + Subclasses should override the `action` method to implement custom action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class ActionClipping(ActionProcessor): + def __init__(self, min_val, max_val): + self.min_val = min_val + self.max_val = max_val + + def action(self, action): + return np.clip(action, self.min_val, self.max_val) + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific action processing logic. + """ + + def action(self, action): + """Process the action component. + + Args: + action: The action to process + + Returns: + The processed action + """ + return action + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + + processed_action = self.action(action) + # Create a new transition dict with the processed action + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class RewardProcessor: + """Base class for processors that modify only the reward component of a transition. + + Subclasses should override the `reward` method to implement custom reward processing. + This class handles the boilerplate of extracting and reinserting the processed reward + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class RewardScaler(RewardProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def reward(self, reward): + return reward * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific reward processing logic. + """ + + def reward(self, reward): + """Process the reward component. + + Args: + reward: The reward to process + + Returns: + The processed reward + """ + return reward + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + if reward is None: + return transition + + processed_reward = self.reward(reward) + # Create a new transition dict with the processed reward + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = processed_reward + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class DoneProcessor: + """Base class for processors that modify only the done flag of a transition. + + Subclasses should override the `done` method to implement custom done flag processing. + This class handles the boilerplate of extracting and reinserting the processed done flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class TimeoutDone(DoneProcessor): + def __init__(self, max_steps): + self.steps = 0 + self.max_steps = max_steps + + def done(self, done): + self.steps += 1 + return done or self.steps >= self.max_steps + + def reset(self): + self.steps = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific done flag processing logic. + """ + + def done(self, done): + """Process the done flag. + + Args: + done: The done flag to process + + Returns: + The processed done flag + """ + return done + + def __call__(self, transition: EnvTransition) -> EnvTransition: + done = transition.get(TransitionKey.DONE) + if done is None: + return transition + + processed_done = self.done(done) + # Create a new transition dict with the processed done flag + new_transition = transition.copy() + new_transition[TransitionKey.DONE] = processed_done + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class TruncatedProcessor: + """Base class for processors that modify only the truncated flag of a transition. + + Subclasses should override the `truncated` method to implement custom truncated flag processing. + This class handles the boilerplate of extracting and reinserting the processed truncated flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class EarlyTruncation(TruncatedProcessor): + def __init__(self, threshold): + self.threshold = threshold + + def truncated(self, truncated): + # Additional truncation condition + return truncated or some_condition > self.threshold + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific truncated flag processing logic. + """ + + def truncated(self, truncated): + """Process the truncated flag. + + Args: + truncated: The truncated flag to process + + Returns: + The processed truncated flag + """ + return truncated + + def __call__(self, transition: EnvTransition) -> EnvTransition: + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is None: + return transition + + processed_truncated = self.truncated(truncated) + # Create a new transition dict with the processed truncated flag + new_transition = transition.copy() + new_transition[TransitionKey.TRUNCATED] = processed_truncated + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class InfoProcessor: + """Base class for processors that modify only the info dictionary of a transition. + + Subclasses should override the `info` method to implement custom info processing. + This class handles the boilerplate of extracting and reinserting the processed info + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class InfoAugmenter(InfoProcessor): + def __init__(self): + self.step_count = 0 + + def info(self, info): + info = info.copy() # Create a copy to avoid modifying the original + info["steps"] = self.step_count + self.step_count += 1 + return info + + def reset(self): + self.step_count = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific info dictionary processing logic. + """ + + def info(self, info): + """Process the info dictionary. + + Args: + info: The info dictionary to process + + Returns: + The processed info dictionary + """ + return info + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO) + if info is None: + return transition + + processed_info = self.info(info) + # Create a new transition dict with the processed info + new_transition = transition.copy() + new_transition[TransitionKey.INFO] = processed_info + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ComplementaryDataProcessor: + """Base class for processors that modify only the complementary data of a transition. + + Subclasses should override the `complementary_data` method to implement custom complementary data processing. + This class handles the boilerplate of extracting and reinserting the processed complementary data + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + """ + + def complementary_data(self, complementary_data): + """Process the complementary data. 
+ + Args: + complementary_data: The complementary data to process + + Returns: + The processed complementary data + """ + return complementary_data + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return transition + + processed_complementary_data = self.complementary_data(complementary_data) + # Create a new transition dict with the processed complementary data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class IdentityProcessor: + """Identity processor that does nothing.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py new file mode 100644 index 000000000..4fe4105a5 --- /dev/null +++ b/src/lerobot/processor/rename_processor.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
from dataclasses import dataclass, field
from typing import Any

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
    ObservationProcessor,
    ProcessorStepRegistry,
)


@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor(ObservationProcessor):
    """Renames observation keys according to `rename_map`.

    Keys listed in `rename_map` are replaced by their mapped name; all other
    keys (and every value) pass through unchanged.
    """

    rename_map: dict[str, str] = field(default_factory=dict)

    def observation(self, observation):
        # Keys absent from the map fall through unchanged.
        return {self.rename_map.get(key, key): value for key, value in observation.items()}

    def get_config(self) -> dict[str, Any]:
        # Return a copy: handing out the live dict would let callers mutate the
        # processor's internal mapping through the serialized config.
        return {"rename_map": dict(self.rename_map)}

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms:
        - Each key in the observation that appears in `rename_map` is renamed to its value.
        - Keys not in `rename_map` remain unchanged.
        """
        return {self.rename_map.get(k, k): v for k, v in features.items()}


# --- tests/conftest.py additions ---


@pytest.fixture
def policy_feature_factory():
    """PolicyFeature factory"""

    def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
        return PolicyFeature(type=ft, shape=shape)

    return _pf


def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
    # A feature contract must map str keys to PolicyFeature values.
    assert isinstance(features, dict)
    assert all(isinstance(key, str) for key in features)
    assert all(isinstance(value, PolicyFeature) for value in features.values())


# --- tests/processor/test_batch_conversion.py ---


def _dummy_batch():
    """Create a dummy batch using the new format with observation.* and next.* keys."""
    return {
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.image.right": torch.randn(1, 3, 128, 128),
        "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
        "action": torch.tensor([[0.5]]),
        "next.reward": 1.0,
        "next.done": False,
        "next.truncated": False,
        "info": {"key": "value"},
    }


def test_observation_grouping_roundtrip():
    """An empty pipeline must return the batch unchanged, key for key."""
    batch_in = _dummy_batch()
    batch_out = RobotProcessor([])(batch_in)

    # The same observation.* keys exist before and after.
    obs_in = {k for k in batch_in if k.startswith("observation.")}
    obs_out = {k for k in batch_out if k.startswith("observation.")}
    assert obs_in == obs_out

    # Tensor payloads are unchanged.
    for key in ("observation.image.left", "observation.image.right", "observation.state", "action"):
        assert torch.allclose(batch_out[key], batch_in[key])

    # Scalar fields are unchanged.
    for key in ("next.reward", "next.done", "next.truncated", "info"):
        assert batch_out[key] == batch_in[key]
def test_batch_to_transition_observation_grouping():
    """_default_batch_to_transition groups all observation.* keys into one dict."""
    batch = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
        "action": "action_data",
        "next.reward": 1.5,
        "next.done": True,
        "next.truncated": False,
        "info": {"episode": 42},
    }

    transition = _default_batch_to_transition(batch)
    obs = transition[TransitionKey.OBSERVATION]

    # Every observation.* key lands inside the grouped observation dict.
    assert isinstance(obs, dict)
    for key in ("observation.image.top", "observation.image.left", "observation.state"):
        assert key in obs

    # Values survive the grouping untouched.
    assert torch.allclose(obs["observation.image.top"], batch["observation.image.top"])
    assert torch.allclose(obs["observation.image.left"], batch["observation.image.left"])
    assert obs["observation.state"] == [1, 2, 3, 4]

    # Remaining batch fields map onto their transition slots.
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 1.5
    assert transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"episode": 42}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}


def test_transition_to_batch_observation_flattening():
    """_default_transition_to_batch flattens the observation dict back out."""
    obs = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
    }
    transition = {
        TransitionKey.OBSERVATION: obs,
        TransitionKey.ACTION: "action_data",
        TransitionKey.REWARD: 1.5,
        TransitionKey.DONE: True,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: {"episode": 42},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }

    batch = _default_transition_to_batch(transition)

    # observation.* keys are restored at the top level of the batch.
    for key in obs:
        assert key in batch
    assert torch.allclose(batch["observation.image.top"], obs["observation.image.top"])
    assert torch.allclose(batch["observation.image.left"], obs["observation.image.left"])
    assert batch["observation.state"] == [1, 2, 3, 4]

    # The scalar fields come back under the next.* naming scheme.
    assert batch["action"] == "action_data"
    assert batch["next.reward"] == 1.5
    assert batch["next.done"]
    assert not batch["next.truncated"]
    assert batch["info"] == {"episode": 42}


def test_no_observation_keys():
    """A batch without observation.* keys yields a None observation."""
    batch = {
        "action": "action_data",
        "next.reward": 2.0,
        "next.done": False,
        "next.truncated": True,
        "info": {"test": "no_obs"},
    }

    transition = _default_batch_to_transition(batch)

    assert transition[TransitionKey.OBSERVATION] is None
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 2.0
    assert not transition[TransitionKey.DONE]
    assert transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"test": "no_obs"}

    # The round trip back to a batch keeps every field intact.
    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["action"] == "action_data"
    assert rebuilt["next.reward"] == 2.0
    assert not rebuilt["next.done"]
    assert rebuilt["next.truncated"]
    assert rebuilt["info"] == {"test": "no_obs"}


def test_minimal_batch():
    """A batch with just one observation and an action gets default fields."""
    batch = {"observation.state": "minimal_state", "action": "minimal_action"}

    transition = _default_batch_to_transition(batch)

    assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
    assert transition[TransitionKey.ACTION] == "minimal_action"

    # Missing fields fall back to their defaults.
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["observation.state"] == "minimal_state"
    assert rebuilt["action"] == "minimal_action"
    assert rebuilt["next.reward"] == 0.0
    assert not rebuilt["next.done"]
    assert not rebuilt["next.truncated"]
    assert rebuilt["info"] == {}


def test_empty_batch():
    """An empty batch converts to a transition of pure defaults."""
    transition = _default_batch_to_transition({})

    assert transition[TransitionKey.OBSERVATION] is None
    assert transition[TransitionKey.ACTION] is None
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    rebuilt = _default_transition_to_batch(transition)
    assert rebuilt["action"] is None
    assert rebuilt["next.reward"] == 0.0
    assert not rebuilt["next.done"]
    assert not rebuilt["next.truncated"]
    assert rebuilt["info"] == {}


def test_complex_nested_observation():
    """Nested observation payloads survive a full batch round trip."""
    batch = {
        "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
        "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
        "observation.state": torch.randn(7),
        "action": torch.randn(8),
        "next.reward": 3.14,
        "next.done": False,
        "next.truncated": True,
        "info": {"episode_length": 200, "success": True},
    }

    rebuilt = _default_transition_to_batch(_default_batch_to_transition(batch))

    # The sets of observation.* keys match exactly.
    assert {k for k in batch if k.startswith("observation.")} == {
        k for k in rebuilt if k.startswith("observation.")
    }

    # Tensors — bare and nested inside dicts — are preserved.
    assert torch.allclose(batch["observation.state"], rebuilt["observation.state"])
    assert torch.allclose(batch["observation.image.top"]["image"], rebuilt["observation.image.top"]["image"])
    assert torch.allclose(
        batch["observation.image.left"]["image"], rebuilt["observation.image.left"]["image"]
    )
    assert torch.allclose(batch["action"], rebuilt["action"])

    # Scalar fields round-trip as well.
    assert batch["next.reward"] == rebuilt["next.reward"]
    assert batch["next.done"] == rebuilt["next.done"]
    assert batch["next.truncated"] == rebuilt["next.truncated"]
    assert batch["info"] == rebuilt["info"]


def test_custom_converter():
    """RobotProcessor honors user-supplied batch/transition converters."""

    def doubling_to_transition(batch):
        # Wrap the default converter but double the reward on the way in.
        tr = _default_batch_to_transition(batch)
        reward = tr.get(TransitionKey.REWARD, 0.0)
        patched = tr.copy()
        patched[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
        return patched

    def default_to_batch(tr):
        return _default_transition_to_batch(tr)

    processor = RobotProcessor(steps=[], to_transition=doubling_to_transition, to_output=default_to_batch)

    batch = {
        "observation.state": torch.randn(1, 4),
        "action": torch.randn(1, 2),
        "next.reward": 1.0,
        "next.done": False,
    }
    result = processor(batch)

    # The custom converter doubled the reward; everything else is untouched.
    assert result["next.reward"] == 2.0
    assert torch.allclose(result["observation.state"], batch["observation.state"])
    assert torch.allclose(result["action"], batch["action"])
from unittest.mock import Mock

import numpy as np
import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.normalize_processor import (
    NormalizerProcessor,
    UnnormalizerProcessor,
    _convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionKey


def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Helper to create an EnvTransition dictionary."""
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
    }


def test_numpy_conversion():
    """Numpy stats arrays are converted to torch tensors."""
    stats = {
        "observation.image": {
            "mean": np.array([0.5, 0.5, 0.5]),
            "std": np.array([0.2, 0.2, 0.2]),
        }
    }

    converted = _convert_stats_to_tensors(stats)
    image_stats = converted["observation.image"]

    assert isinstance(image_stats["mean"], torch.Tensor)
    assert isinstance(image_stats["std"], torch.Tensor)
    assert torch.allclose(image_stats["mean"], torch.tensor([0.5, 0.5, 0.5]))
    assert torch.allclose(image_stats["std"], torch.tensor([0.2, 0.2, 0.2]))


def test_tensor_conversion():
    """Existing tensors come out of the conversion as float32."""
    stats = {
        "action": {
            "mean": torch.tensor([0.0, 0.0]),
            "std": torch.tensor([1.0, 1.0]),
        }
    }

    converted = _convert_stats_to_tensors(stats)

    assert converted["action"]["mean"].dtype == torch.float32
    assert converted["action"]["std"].dtype == torch.float32


def test_scalar_conversion():
    """Python scalars become zero-dim tensors."""
    converted = _convert_stats_to_tensors({"reward": {"mean": 0.5, "std": 0.1}})

    assert torch.allclose(converted["reward"]["mean"], torch.tensor(0.5))
    assert torch.allclose(converted["reward"]["std"], torch.tensor(0.1))


def test_list_conversion():
    """Plain lists become 1-D tensors."""
    stats = {
        "observation.state": {
            "min": [0.0, -1.0, -2.0],
            "max": [1.0, 1.0, 2.0],
        }
    }

    converted = _convert_stats_to_tensors(stats)

    assert torch.allclose(converted["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
    assert torch.allclose(converted["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))


def test_unsupported_type():
    """Non-numeric stats values raise a TypeError."""
    with pytest.raises(TypeError, match="Unsupported type"):
        _convert_stats_to_tensors({"bad_key": {"mean": "string_value"}})


# Helpers that build the feature map / normalization map used by the
# observation-normalization tests below.
def _create_observation_features():
    return {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
    }


def _create_observation_norm_map():
    return {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.STATE: NormalizationMode.MIN_MAX,
    }


@pytest.fixture
def observation_stats():
    return {
        "observation.image": {
            "mean": np.array([0.5, 0.5, 0.5]),
            "std": np.array([0.2, 0.2, 0.2]),
        },
        "observation.state": {
            "min": np.array([0.0, -1.0]),
            "max": np.array([1.0, 1.0]),
        },
    }


@pytest.fixture
def observation_normalizer(observation_stats):
    """Return a NormalizerProcessor that only has observation stats (no action)."""
    return NormalizerProcessor(
        features=_create_observation_features(),
        norm_map=_create_observation_norm_map(),
        stats=observation_stats,
    )
def test_mean_std_normalization(observation_normalizer):
    """VISUAL features are normalized as (x - mean) / std."""
    obs = {
        "observation.image": torch.tensor([0.7, 0.5, 0.3]),
        "observation.state": torch.tensor([0.5, 0.0]),
    }

    out = observation_normalizer(create_transition(observation=obs))[TransitionKey.OBSERVATION]

    expected = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
    assert torch.allclose(out["observation.image"], expected)


def test_min_max_normalization(observation_normalizer):
    """STATE features are mapped into [-1, 1] via min/max scaling."""
    obs = {"observation.state": torch.tensor([0.5, 0.0])}

    out = observation_normalizer(create_transition(observation=obs))[TransitionKey.OBSERVATION]

    # 2 * (x - min) / (max - min) - 1 yields 0.0 for both components here.
    assert torch.allclose(out["observation.state"], torch.tensor([0.0, 0.0]), atol=1e-6)


def test_selective_normalization(observation_stats):
    """Only keys listed in normalize_keys are touched."""
    normalizer = NormalizerProcessor(
        features=_create_observation_features(),
        norm_map=_create_observation_norm_map(),
        stats=observation_stats,
        normalize_keys={"observation.image"},
    )

    obs = {
        "observation.image": torch.tensor([0.7, 0.5, 0.3]),
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    out = normalizer(create_transition(observation=obs))[TransitionKey.OBSERVATION]

    # Image is normalized, state passes through untouched.
    assert torch.allclose(out["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
    assert torch.allclose(out["observation.state"], obs["observation.state"])


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_device_compatibility(observation_stats):
    """Normalization keeps tensors on their original (CUDA) device."""
    normalizer = NormalizerProcessor(
        features=_create_observation_features(),
        norm_map=_create_observation_norm_map(),
        stats=observation_stats,
    )
    obs = {"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda()}

    out = normalizer(create_transition(observation=obs))[TransitionKey.OBSERVATION]

    assert out["observation.image"].device.type == "cuda"


def test_from_lerobot_dataset():
    """from_lerobot_dataset pulls both observation and action stats."""
    mock_dataset = Mock()
    mock_dataset.meta.stats = {
        "observation.image": {"mean": [0.5], "std": [0.2]},
        "action": {"mean": [0.0], "std": [1.0]},
    }
    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "action": PolicyFeature(FeatureType.ACTION, (1,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)

    assert "observation.image" in normalizer._tensor_stats
    assert "action" in normalizer._tensor_stats


def test_state_dict_save_load(observation_normalizer):
    """A normalizer restored via load_state_dict behaves like the original."""
    snapshot = observation_normalizer.state_dict()

    restored = NormalizerProcessor(
        features=_create_observation_features(),
        norm_map=_create_observation_norm_map(),
        stats={},
    )
    restored.load_state_dict(snapshot)

    transition = create_transition(observation={"observation.image": torch.tensor([0.7, 0.5, 0.3])})
    out_original = observation_normalizer(transition)[TransitionKey.OBSERVATION]
    out_restored = restored(transition)[TransitionKey.OBSERVATION]

    assert torch.allclose(out_original["observation.image"], out_restored["observation.image"])


# Fixtures and helpers for the action (un)normalization tests.
@pytest.fixture
def action_stats_mean_std():
    return {
        "mean": np.array([0.0, 0.0, 0.0]),
        "std": np.array([1.0, 2.0, 0.5]),
    }


@pytest.fixture
def action_stats_min_max():
    return {
        "min": np.array([-1.0, -2.0, 0.0]),
        "max": np.array([1.0, 2.0, 1.0]),
    }


def _create_action_features():
    return {"action": PolicyFeature(FeatureType.ACTION, (3,))}


def _create_action_norm_map_mean_std():
    return {FeatureType.ACTION: NormalizationMode.MEAN_STD}


def _create_action_norm_map_min_max():
    return {FeatureType.ACTION: NormalizationMode.MIN_MAX}


def test_mean_std_unnormalization(action_stats_mean_std):
    """Unnormalization inverts mean/std scaling: x * std + mean."""
    unnormalizer = UnnormalizerProcessor(
        features=_create_action_features(),
        norm_map=_create_action_norm_map_mean_std(),
        stats={"action": action_stats_mean_std},
    )

    out = unnormalizer(create_transition(action=torch.tensor([1.0, -0.5, 2.0])))[TransitionKey.ACTION]

    expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
    assert torch.allclose(out, expected)


def test_min_max_unnormalization(action_stats_min_max):
    """Unnormalization maps [-1, 1] actions back into [min, max]."""
    unnormalizer = UnnormalizerProcessor(
        features=_create_action_features(),
        norm_map=_create_action_norm_map_min_max(),
        stats={"action": action_stats_min_max},
    )

    out = unnormalizer(create_transition(action=torch.tensor([0.0, -1.0, 1.0])))[TransitionKey.ACTION]

    # (action + 1) / 2 * (max - min) + min, per component.
    expected = torch.tensor(
        [
            (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0),  # 0.0
            (-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0),  # -2.0
            (1.0 + 1) / 2 * (1.0 - 0.0) + 0.0,  # 1.0
        ]
    )
    assert torch.allclose(out, expected)


def test_numpy_action_input(action_stats_mean_std):
    """Numpy actions are converted to tensors before unnormalization."""
    unnormalizer = UnnormalizerProcessor(
        features=_create_action_features(),
        norm_map=_create_action_norm_map_mean_std(),
        stats={"action": action_stats_mean_std},
    )

    action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
    out = unnormalizer(create_transition(action=action))[TransitionKey.ACTION]

    assert isinstance(out, torch.Tensor)
    assert torch.allclose(out, torch.tensor([1.0, -1.0, 1.0]))


def test_none_action(action_stats_mean_std):
    """A transition without an action passes through unchanged."""
    unnormalizer = UnnormalizerProcessor(
        features=_create_action_features(),
        norm_map=_create_action_norm_map_mean_std(),
        stats={"action": action_stats_mean_std},
    )

    transition = create_transition()
    assert unnormalizer(transition) == transition


def test_action_from_lerobot_dataset():
    """UnnormalizerProcessor.from_lerobot_dataset loads action stats."""
    mock_dataset = Mock()
    mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}

    unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(
        mock_dataset,
        {"action": PolicyFeature(FeatureType.ACTION, (1,))},
        {FeatureType.ACTION: NormalizationMode.MEAN_STD},
    )

    assert "mean" in unnormalizer._tensor_stats["action"]


# Fixtures for NormalizerProcessor tests
@pytest.fixture
def full_stats():
    return {
        "observation.image": {
            "mean": np.array([0.5, 0.5, 0.5]),
            "std": np.array([0.2, 0.2, 0.2]),
        },
        "observation.state": {
            "min": np.array([0.0, -1.0]),
            "max": np.array([1.0, 1.0]),
        },
        "action": {
            "mean": np.array([0.0, 0.0]),
            "std": np.array([1.0, 2.0]),
        },
    }


def _create_full_features():
    return {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }


def _create_full_norm_map():
    return {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.STATE: NormalizationMode.MIN_MAX,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }


@pytest.fixture
def normalizer_processor(full_stats):
    return NormalizerProcessor(
        features=_create_full_features(), norm_map=_create_full_norm_map(), stats=full_stats
    )


def test_combined_normalization(normalizer_processor):
    """Observations and actions are normalized together; the rest is untouched."""
    obs = {
        "observation.image": torch.tensor([0.7, 0.5, 0.3]),
        "observation.state": torch.tensor([0.5, 0.0]),
    }
    transition = create_transition(
        observation=obs,
        action=torch.tensor([1.0, -0.5]),
        reward=1.0,
        done=False,
        truncated=False,
        info={},
        complementary_data={},
    )

    out = normalizer_processor(transition)

    # Image: mean/std normalization.
    assert torch.allclose(
        out[TransitionKey.OBSERVATION]["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
    )
    # Action: mean/std normalization with per-component std.
    assert torch.allclose(out[TransitionKey.ACTION], torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]))
    # Untouched fields.
    assert out[TransitionKey.REWARD] == 1.0
    assert not out[TransitionKey.DONE]


def test_processor_from_lerobot_dataset(full_stats):
    """from_lerobot_dataset keeps normalize_keys and loads every stats entry."""
    mock_dataset = Mock()
    mock_dataset.meta.stats = full_stats

    processor = NormalizerProcessor.from_lerobot_dataset(
        mock_dataset, _create_full_features(), _create_full_norm_map(), normalize_keys={"observation.image"}
    )

    assert processor.normalize_keys == {"observation.image"}
    assert "observation.image" in processor._tensor_stats
    assert "action" in processor._tensor_stats


def test_get_config(full_stats):
    """get_config serializes features, norm_map, normalize_keys and eps."""
    processor = NormalizerProcessor(
        features=_create_full_features(),
        norm_map=_create_full_norm_map(),
        stats=full_stats,
        normalize_keys={"observation.image"},
        eps=1e-6,
    )

    assert processor.get_config() == {
        "normalize_keys": ["observation.image"],
        "eps": 1e-6,
        "features": {
            "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
            "observation.state": {"type": "STATE", "shape": (2,)},
            "action": {"type": "ACTION", "shape": (2,)},
        },
        "norm_map": {
            "VISUAL": "MEAN_STD",
            "STATE": "MIN_MAX",
            "ACTION": "MEAN_STD",
        },
    }


def test_integration_with_robot_processor(normalizer_processor):
    """Test integration with RobotProcessor pipeline"""
    pipeline = RobotProcessor([normalizer_processor])

    transition = create_transition(
        observation={
            "observation.image": torch.tensor([0.7, 0.5, 0.3]),
            "observation.state": torch.tensor([0.5, 0.0]),
        },
        action=torch.tensor([1.0, -0.5]),
        reward=1.0,
        done=False,
        truncated=False,
        info={},
        complementary_data={},
    )

    out = pipeline(transition)

    # Verify the processing produced the expected container types.
    assert isinstance(out[TransitionKey.OBSERVATION], dict)
    assert isinstance(out[TransitionKey.ACTION], torch.Tensor)


# Edge case tests
def test_empty_observation():
    """A transition without an observation is returned unchanged."""
    normalizer = NormalizerProcessor(
        features={"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))},
        norm_map={FeatureType.VISUAL: NormalizationMode.MEAN_STD},
        stats={"observation.image": {"mean": [0.5], "std": [0.2]}},
    )

    transition = create_transition()
    assert normalizer(transition) == transition


def test_empty_stats():
    """Without stats the observation passes through unchanged."""
    normalizer = NormalizerProcessor(
        features={"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))},
        norm_map={FeatureType.VISUAL: NormalizationMode.MEAN_STD},
        stats={},
    )
    obs = {"observation.image": torch.tensor([0.5])}

    out = normalizer(create_transition(observation=obs))
    assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.image"], obs["observation.image"])


def test_partial_stats():
    """If statistics are incomplete, the value should pass through unchanged."""
    normalizer = NormalizerProcessor(
        features={"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))},
        norm_map={FeatureType.VISUAL: NormalizationMode.MEAN_STD},
        stats={"observation.image": {"mean": [0.5]}},  # Missing std / (min,max)
    )
    obs = {"observation.image": torch.tensor([0.7])}

    out = normalizer(create_transition(observation=obs))[TransitionKey.OBSERVATION]
    assert torch.allclose(out["observation.image"], obs["observation.image"])


def test_missing_action_stats_no_error():
    """Missing action stats in the dataset must not raise."""
    mock_dataset = Mock()
    mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}

    processor = UnnormalizerProcessor.from_lerobot_dataset(
        mock_dataset,
        {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))},
        {FeatureType.VISUAL: NormalizationMode.MEAN_STD},
    )

    # The tensor stats should not contain the 'action' key.
    assert "action" not in processor._tensor_stats


def test_serialization_roundtrip(full_stats):
    """Test that features and norm_map can be serialized and deserialized correctly."""
    original = NormalizerProcessor(
        features=_create_full_features(),
        norm_map=_create_full_norm_map(),
        stats=full_stats,
        normalize_keys={"observation.image"},
        eps=1e-6,
    )

    # Serialize, then rebuild a processor from the resulting config.
    config = original.get_config()
    rebuilt = NormalizerProcessor(
        features=config["features"],
        norm_map=config["norm_map"],
        stats=full_stats,
        normalize_keys=set(config["normalize_keys"]),
        eps=config["eps"],
    )

    transition = create_transition(
        observation={
            "observation.image": torch.tensor([0.7, 0.5, 0.3]),
            "observation.state": torch.tensor([0.5, 0.0]),
        },
        action=torch.tensor([1.0, -0.5]),
        reward=1.0,
        done=False,
        truncated=False,
        info={},
        complementary_data={},
    )

    out_original = original(transition)
    out_rebuilt = rebuilt(transition)

    # Identical behavior on observations and actions.
    assert torch.allclose(
        out_original[TransitionKey.OBSERVATION]["observation.image"],
        out_rebuilt[TransitionKey.OBSERVATION]["observation.image"],
    )
    assert torch.allclose(out_original[TransitionKey.ACTION], out_rebuilt[TransitionKey.ACTION])

    # The reconstructed features/norm_map match the originals.
    assert rebuilt.features.keys() == original.features.keys()
    for key in rebuilt.features:
        assert rebuilt.features[key].type == original.features[key].type
        assert rebuilt.features[key].shape == original.features[key].shape
    assert rebuilt.norm_map == original.norm_map
+ +import numpy as np +import pytest +import torch + +from lerobot.configs.types import FeatureType +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import VanillaObservationProcessor +from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_process_single_image(): + """Test processing a single image.""" + processor = VanillaObservationProcessor() + + # Create a mock image (H, W, C) format, uint8 + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that the image was processed correctly + assert "observation.image" in processed_obs + processed_img = processed_obs["observation.image"] + + # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width + assert processed_img.shape == (1, 3, 64, 64) + + # Check dtype and range + assert processed_img.dtype == torch.float32 + assert processed_img.min() >= 0.0 + assert processed_img.max() <= 1.0 + + +def test_process_image_dict(): + """Test processing multiple images in a dictionary.""" + processor = VanillaObservationProcessor() + + # Create mock images + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + + observation = {"pixels": {"camera1": 
image1, "camera2": image2}} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both images were processed + assert "observation.images.camera1" in processed_obs + assert "observation.images.camera2" in processed_obs + + # Check shapes + assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) + assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + + +def test_process_batched_image(): + """Test processing already batched images.""" + processor = VanillaObservationProcessor() + + # Create a batched image (B, H, W, C) + image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimension is preserved + assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + + +def test_invalid_image_format(): + """Test error handling for invalid image formats.""" + processor = VanillaObservationProcessor() + + # Test wrong channel order (channels first) + image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Expected channel-last images"): + processor(transition) + + +def test_invalid_image_dtype(): + """Test error handling for invalid image dtype.""" + processor = VanillaObservationProcessor() + + # Test wrong dtype + image = np.random.rand(64, 64, 3).astype(np.float32) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Expected torch.uint8 images"): + processor(transition) + + +def test_no_pixels_in_observation(): + """Test processor when no pixels are in observation.""" + processor 
= VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve other data unchanged + assert "other_data" in processed_obs + np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3])) + + +def test_none_observation(): + """Test processor with None observation.""" + processor = VanillaObservationProcessor() + + transition = create_transition() + result = processor(transition) + + assert result == transition + + +def test_serialization_methods(): + """Test serialization methods.""" + processor = VanillaObservationProcessor() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + + # Test load_state_dict (should not raise) + processor.load_state_dict(state) + + # Test reset (should not raise) + processor.reset() + + +def test_process_environment_state(): + """Test processing environment_state.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + observation = {"environment_state": env_state} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that environment_state was renamed and processed + assert "observation.environment_state" in processed_obs + assert "environment_state" not in processed_obs + + processed_state = processed_obs["observation.environment_state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) + + +def test_process_agent_pos(): + """Test processing agent_pos.""" + processor = VanillaObservationProcessor() + + agent_pos = 
np.array([0.5, -0.5, 1.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that agent_pos was renamed and processed + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + processed_state = processed_obs["observation.state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) + + +def test_process_batched_states(): + """Test processing already batched states.""" + processor = VanillaObservationProcessor() + + env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimensions are preserved + assert processed_obs["observation.environment_state"].shape == (2, 2) + assert processed_obs["observation.state"].shape == (2, 2) + + +def test_process_both_states(): + """Test processing both environment_state and agent_pos.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert 
"environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "keep_me" + + +def test_no_states_in_observation(): + """Test processor when no states are in observation.""" + processor = VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve data unchanged + np.testing.assert_array_equal(processed_obs, observation) + + +def test_complete_observation_processing(): + """Test processing a complete observation with both images and states.""" + processor = VanillaObservationProcessor() + + # Create mock data + image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = { + "pixels": image, + "environment_state": env_state, + "agent_pos": agent_pos, + "other_data": "preserve_me", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that image was processed + assert "observation.image" in processed_obs + assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + + # Check that states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "pixels" not in processed_obs + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "preserve_me" + + +def test_image_only_processing(): + """Test processing observation with only images.""" + processor = VanillaObservationProcessor() + + image = 
np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.image" in processed_obs + assert len(processed_obs) == 1 + + +def test_state_only_processing(): + """Test processing observation with only states.""" + processor = VanillaObservationProcessor() + + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + +def test_empty_observation(): + """Test processing empty observation.""" + processor = VanillaObservationProcessor() + + observation = {} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs == {} + + +def test_equivalent_to_original_function(): + """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" + # Import the original function for comparison + from lerobot.envs.utils import preprocess_observation + + processor = VanillaObservationProcessor() + + # Create test data similar to what the original function expects + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = 
processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_equivalent_with_image_dict(): + """Test equivalence with dictionary of images.""" + from lerobot.envs.utils import preprocess_observation + + processor = VanillaObservationProcessor() + + # Create test data with multiple cameras + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + + observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] + assert "pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": 
policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] + assert "observation.pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + } + out = processor.feature_contract(features.copy()) + + assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] + assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] + assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] + assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def 
test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): + proc = VanillaObservationProcessor() + features = { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + } + out = proc.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert_contract_is_typed(out) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py new file mode 100644 index 000000000..5665d5a7d --- /dev/null +++ b/tests/processor/test_pipeline.py @@ -0,0 +1,1919 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import json
import tempfile
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pytest
import torch
import torch.nn as nn

from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
from tests.conftest import assert_contract_is_typed


def create_transition(
    observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
    """Helper to create an EnvTransition dictionary."""
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info if info is not None else {},
        TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
    }


@dataclass
class MockStep:
    """Mock pipeline step for testing - demonstrates best practices.

    This example shows the proper separation:
    - JSON-serializable attributes (name, counter) go in get_config()
    - Only torch tensors go in state_dict()

    Note: The counter is part of the configuration, so it will be restored
    when the step is recreated from config during loading.
    """

    name: str = "mock_step"
    counter: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Add a counter to the complementary_data."""
        comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        comp_data = {} if comp_data is None else dict(comp_data)  # Make a copy

        comp_data[f"{self.name}_counter"] = self.counter
        self.counter += 1

        # Create a new transition with updated complementary_data
        new_transition = transition.copy()
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
        return new_transition

    def get_config(self) -> dict[str, Any]:
        # Return all JSON-serializable attributes that should be persisted
        # These will be passed to __init__ when loading
        return {"name": self.name, "counter": self.counter}

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Only return torch tensors (empty in this case since we have no tensor state)
        return {}

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        # No tensor state to load
        pass

    def reset(self) -> None:
        self.counter = 0

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We do not test feature_contract here
        return features


@dataclass
class MockStepWithoutOptionalMethods:
    """Mock step that only implements the required __call__ method."""

    multiplier: float = 2.0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Multiply reward by multiplier."""
        reward = transition.get(TransitionKey.REWARD)

        if reward is not None:
            new_transition = transition.copy()
            new_transition[TransitionKey.REWARD] = reward * self.multiplier
            return new_transition

        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We do not test feature_contract here
        return features


@dataclass
class MockStepWithTensorState:
    """Mock step demonstrating mixed JSON attributes and tensor state."""

    name: str = "tensor_step"
    learning_rate: float = 0.01
    window_size: int = 10

    # NOTE(review): the explicit __init__ below is kept even though the class is a
    # @dataclass — dataclass does not overwrite a user-defined __init__, and the
    # manual one is needed to build the tensor buffers.
    def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10):
        self.name = name
        self.learning_rate = learning_rate
        self.window_size = window_size
        # Tensor state
        self.running_mean = torch.zeros(window_size)
        self.running_count = torch.tensor(0)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Update running statistics."""
        reward = transition.get(TransitionKey.REWARD)

        if reward is not None:
            # Update running mean
            idx = self.running_count % self.window_size
            self.running_mean[idx] = reward
            self.running_count += 1

        return transition

    def get_config(self) -> dict[str, Any]:
        # Only JSON-serializable attributes
        return {
            "name": self.name,
            "learning_rate": self.learning_rate,
            "window_size": self.window_size,
        }

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Only tensor state
        return {
            "running_mean": self.running_mean,
            "running_count": self.running_count,
        }

    def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        self.running_mean = state["running_mean"]
        self.running_count = state["running_count"]

    def reset(self) -> None:
        self.running_mean.zero_()
        self.running_count.zero_()

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        # We do not test feature_contract here
        return features


def test_empty_pipeline():
    """Test pipeline with no steps."""
    pipeline = RobotProcessor()

    transition = create_transition()
    result = pipeline(transition)

    assert result == transition
    assert len(pipeline) == 0


def test_single_step_pipeline():
    """Test pipeline with a single step."""
    step = MockStep("test_step")
    pipeline = RobotProcessor([step])

    transition = create_transition()
    result = pipeline(transition)

    assert len(pipeline) == 1
    assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0

    # Call again to test counter increment
    result = pipeline(transition)
    assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1


def test_multiple_steps_pipeline():
    """Test pipeline with multiple steps."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotProcessor([step1, step2])

    transition = create_transition()
    result = pipeline(transition)

    assert len(pipeline) == 2
    assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0
    assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0


def test_invalid_transition_format():
    """Test pipeline with invalid transition format."""
    pipeline = RobotProcessor([MockStep()])

    # Test with wrong type (tuple instead of dict)
    with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
        pipeline((None, None, 0.0, False, False, {}, {}))  # Tuple instead of dict

    # Test with wrong type (string)
    with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
        pipeline("not a dict")


def test_step_through():
    """Test step_through method with dict input."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotProcessor([step1, step2])

    transition = create_transition()

    results = list(pipeline.step_through(transition))

    assert len(results) == 3  # Original + 2 steps
    assert results[0] == transition  # Original
    assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA]  # After step1
    assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA]  # After step2

    # Ensure all results are dicts (same format as input)
    for result in results:
        assert isinstance(result, dict)
        assert all(isinstance(k, TransitionKey) for k in result.keys())


def test_step_through_with_dict():
    """Test step_through method with dict input."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotProcessor([step1, step2])

    batch = {
        "observation.image": None,
        "action": None,
        "next.reward": 0.0,
        "next.done": False,
        "next.truncated": False,
        "info": {},
    }

    results = list(pipeline.step_through(batch))

    assert len(results) == 3  # Original + 2 steps

    # Ensure all results are EnvTransition dicts (regardless of input format)
    for result in results:
        assert isinstance(result, dict)
        # Check that keys are TransitionKey enums or at least valid transition keys
        for key in result:
            assert key in [
                TransitionKey.OBSERVATION,
                TransitionKey.ACTION,
                TransitionKey.REWARD,
                TransitionKey.DONE,
                TransitionKey.TRUNCATED,
                TransitionKey.INFO,
                TransitionKey.COMPLEMENTARY_DATA,
            ]

    # Check that the processing worked - verify step counters in complementary_data
    assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
    assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
    assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0


def test_step_through_no_hooks():
    """Test that step_through doesn't execute hooks."""
    step = MockStep("test_step")
    pipeline = RobotProcessor([step])

    hook_calls = []

    def tracking_hook(idx: int, transition: EnvTransition):
        hook_calls.append(f"hook_called_step_{idx}")

    # Register hooks
    pipeline.register_before_step_hook(tracking_hook)
    pipeline.register_after_step_hook(tracking_hook)

    # Use step_through
    transition = create_transition()
    results = list(pipeline.step_through(transition))

    # Verify step was executed (counter should increment)
    assert len(results) == 2  # Initial + 1 step
    assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0

    # Verify hooks were NOT called
    assert len(hook_calls) == 0

    # Now use __call__ to verify hooks ARE called there
    hook_calls.clear()
    pipeline(transition)

    # Verify hooks were called (before and after for 1 step = 2 calls)
    assert len(hook_calls) == 2
    assert hook_calls == ["hook_called_step_0", "hook_called_step_0"]


def test_indexing():
    """Test pipeline indexing."""
    step1 = MockStep("step1")
    step2 = MockStep("step2")
    pipeline = RobotProcessor([step1, step2])

    # Test integer indexing
    assert pipeline[0] is step1
    assert pipeline[1] is step2

    # Test slice indexing
    sub_pipeline = pipeline[0:1]
    assert isinstance(sub_pipeline, RobotProcessor)
    assert len(sub_pipeline) == 1
    assert sub_pipeline[0] is step1


def test_hooks():
    """Test before/after step hooks."""
    step = MockStep("test_step")
    pipeline = RobotProcessor([step])

    before_calls = []
    after_calls = []

    def before_hook(idx: int, transition: EnvTransition):
        before_calls.append(idx)

    def after_hook(idx: int, transition: EnvTransition):
        after_calls.append(idx)

    pipeline.register_before_step_hook(before_hook)
    pipeline.register_after_step_hook(after_hook)

    transition = create_transition()
    pipeline(transition)

    assert before_calls == [0]
    assert after_calls == [0]


def test_unregister_hooks():
    """Test unregistering hooks from the pipeline."""
    step = MockStep("test_step")
    pipeline = RobotProcessor([step])

    # Test before_step_hook
    before_calls = []

    def before_hook(idx: int, transition: EnvTransition):
        before_calls.append(idx)

    pipeline.register_before_step_hook(before_hook)

    # Verify hook is registered
    transition = create_transition()
    pipeline(transition)
    assert len(before_calls) == 1

    # Unregister and verify it's no longer called
    pipeline.unregister_before_step_hook(before_hook)
    before_calls.clear()
    pipeline(transition)
    assert len(before_calls) == 0

    # Test after_step_hook
    after_calls = []

    def after_hook(idx: int, transition: EnvTransition):
        after_calls.append(idx)

    pipeline.register_after_step_hook(after_hook)
    pipeline(transition)
    assert len(after_calls) == 1

    pipeline.unregister_after_step_hook(after_hook)
    after_calls.clear()
    pipeline(transition)
    assert len(after_calls) == 0


def test_unregister_nonexistent_hook():
    """Test error handling when unregistering hooks that don't exist."""
    pipeline = RobotProcessor([MockStep()])

    def some_hook(idx: int, transition: EnvTransition):
        pass

    def reset_hook():
        pass

    # Test unregistering hooks that were never registered
    with pytest.raises(ValueError, match="not found in before_step_hooks"):
        pipeline.unregister_before_step_hook(some_hook)

    with pytest.raises(ValueError, match="not found in after_step_hooks"):
        pipeline.unregister_after_step_hook(some_hook)


def test_multiple_hooks_and_selective_unregister():
    """Test registering multiple hooks and selectively unregistering them."""
    pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")])

    calls_1 = []
    calls_2 = []
    calls_3 = []

    def hook1(idx: int, transition: EnvTransition):
        calls_1.append(f"hook1_step{idx}")

    def hook2(idx: int, transition: EnvTransition):
        calls_2.append(f"hook2_step{idx}")

    def hook3(idx: int, transition: EnvTransition):
        calls_3.append(f"hook3_step{idx}")

    # Register multiple hooks
    pipeline.register_before_step_hook(hook1)
    pipeline.register_before_step_hook(hook2)
    pipeline.register_before_step_hook(hook3)

    # Run pipeline - all hooks should be called for both steps
    transition = create_transition()
    pipeline(transition)

    assert calls_1 == ["hook1_step0", "hook1_step1"]
    assert calls_2 == ["hook2_step0", "hook2_step1"]
    assert calls_3 == ["hook3_step0", "hook3_step1"]

    # Clear calls
    calls_1.clear()
    calls_2.clear()
    calls_3.clear()

    # Unregister middle hook
    pipeline.unregister_before_step_hook(hook2)

    # Run again - only hook1 and hook3 should be called
    pipeline(transition)

    assert calls_1 == ["hook1_step0", "hook1_step1"]
    assert calls_2 == []  # hook2 was unregistered
    assert calls_3 == ["hook3_step0", "hook3_step1"]


def test_hook_execution_order_documentation():
    """Test and document that hooks are executed sequentially in registration order."""
    pipeline = RobotProcessor([MockStep("step")])

    execution_order = []

    def hook_a(idx: int, transition: EnvTransition):
        execution_order.append("A")

    def hook_b(idx: int, transition: EnvTransition):
        execution_order.append("B")

    def hook_c(idx: int, transition: EnvTransition):
        execution_order.append("C")

    # Register in specific order: A, B, C
    pipeline.register_before_step_hook(hook_a)
    pipeline.register_before_step_hook(hook_b)
    pipeline.register_before_step_hook(hook_c)

    transition = create_transition()
    pipeline(transition)

    # Verify execution order matches registration order
    assert execution_order == ["A", "B", "C"]

    # Test that after unregistering B and re-registering it, it goes to the end
    pipeline.unregister_before_step_hook(hook_b)
    execution_order.clear()

    pipeline(transition)
    assert execution_order == ["A", "C"]  # B is gone

    # Re-register B - it should now be at the end
    pipeline.register_before_step_hook(hook_b)
    execution_order.clear()

    pipeline(transition)
    assert execution_order == ["A", "C", "B"]  # B is now last


def test_save_and_load_pretrained():
    """Test saving and loading pipeline.

    This test demonstrates that JSON-serializable attributes (like counter)
    are saved in the config and restored when the step is recreated.
    """
    step1 = MockStep("step1")
    step2 = MockStep("step2")

    # Increment counters to have some state
    step1.counter = 5
    step2.counter = 10

    pipeline = RobotProcessor([step1, step2], name="TestPipeline")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save pipeline
        pipeline.save_pretrained(tmp_dir)

        # Check files were created
        config_path = Path(tmp_dir) / "testpipeline.json"  # Based on name="TestPipeline"
        assert config_path.exists()

        # Check config content
        with open(config_path) as f:
            config = json.load(f)

        assert config["name"] == "TestPipeline"
        assert len(config["steps"]) == 2

        # Verify counters are saved in config, not in separate state files
        assert config["steps"][0]["config"]["counter"] == 5
        assert config["steps"][1]["config"]["counter"] == 10

        # Load pipeline
        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)

        assert loaded_pipeline.name == "TestPipeline"
        assert len(loaded_pipeline) == 2

        # Check that counter was restored from config
        assert loaded_pipeline.steps[0].counter == 5
        assert loaded_pipeline.steps[1].counter == 10


def test_step_without_optional_methods():
    """Test pipeline with steps that don't implement optional methods."""
    step = MockStepWithoutOptionalMethods(multiplier=3.0)
    pipeline = RobotProcessor([step])

    transition = create_transition(reward=2.0)
    result = pipeline(transition)

    assert result[TransitionKey.REWARD] == 6.0  # 2.0 * 3.0

    # Reset should work even if step doesn't implement reset
    pipeline.reset()

    # Save/load should work even without optional methods
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)
        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
        assert len(loaded_pipeline) == 1


def test_mixed_json_and_tensor_state():
    """Test step with both JSON attributes and tensor state."""
    step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
    pipeline = RobotProcessor([step])

    # Process some transitions with rewards
    for i in range(10):
        transition = create_transition(reward=float(i))
        pipeline(transition)

    # Check state
    assert step.running_count.item() == 10
    assert step.learning_rate == 0.05

    # Save and load
    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Check that both config and state files were created
        config_path = Path(tmp_dir) / "robotprocessor.json"  # Default name is "RobotProcessor"
        state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors"
        assert config_path.exists()
        assert state_path.exists()

        # Load and verify
        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
        loaded_step = loaded_pipeline.steps[0]

        # Check JSON attributes were restored
        assert loaded_step.name == "stats"
        assert loaded_step.learning_rate == 0.05
        assert loaded_step.window_size == 5

        # Check tensor state was restored
        assert loaded_step.running_count.item() == 10
        assert torch.allclose(loaded_step.running_mean, step.running_mean)


class MockModuleStep(nn.Module):
    """Mock step that inherits from nn.Module to test state_dict handling of module parameters."""

    def __init__(self, input_dim: int = 10, hidden_dim: int = 5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.running_mean = nn.Parameter(torch.zeros(hidden_dim), requires_grad=False)
        self.counter = 0  # Non-tensor state

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Process transition and update running mean."""
        obs = transition.get(TransitionKey.OBSERVATION)

        if obs is not None and isinstance(obs, torch.Tensor):
            # Process observation through linear layer
            processed = self.forward(obs[:, : self.input_dim])

            # Update running mean in-place (don't reassign the parameter)
with torch.no_grad(): + self.running_mean.mul_(0.9).add_(processed.mean(dim=0), alpha=0.1) + + self.counter += 1 + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "counter": self.counter, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Override to return all module parameters and buffers.""" + # Get the module's state dict (includes all parameters and buffers) + return super().state_dict() + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Override to load all module parameters and buffers.""" + # Use the module's load_state_dict + super().load_state_dict(state) + + def reset(self) -> None: + self.running_mean.zero_() + self.counter = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockNonModuleStepWithState: + """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. + + This tests the state_dict/load_state_dict path for regular classes. 
+ """ + + def __init__(self, name: str = "non_module_step", feature_dim: int = 10): + self.name = name + self.feature_dim = feature_dim + + # Initialize tensor state - these are regular tensors, not nn.Parameters + self.weights = torch.randn(feature_dim, feature_dim) + self.bias = torch.zeros(feature_dim) + self.running_stats = torch.zeros(feature_dim) + self.step_count = torch.tensor(0) + + # Non-tensor state + self.config_value = 42 + self.history = [] + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process transition using tensor operations.""" + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim: + # Perform some tensor operations + flat_obs = obs.flatten()[: self.feature_dim] + + # Simple linear transformation (ensure dimensions match for matmul) + output = torch.matmul(self.weights.T, flat_obs) + self.bias + + # Update running stats + self.running_stats = 0.9 * self.running_stats + 0.1 * output + self.step_count += 1 + + # Add to complementary data + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_mean_output"] = output.mean().item() + comp_data[f"{self.name}_steps"] = self.step_count.item() + + # Return updated transition + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "name": self.name, + "feature_dim": self.feature_dim, + "config_value": self.config_value, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return only tensor state.""" + return { + "weights": self.weights, + "bias": self.bias, + "running_stats": self.running_stats, + "step_count": self.step_count, + } + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load tensor state.""" + self.weights = 
state["weights"] + self.bias = state["bias"] + self.running_stats = state["running_stats"] + self.step_count = state["step_count"] + + def reset(self) -> None: + """Reset statistics but keep learned parameters.""" + self.running_stats.zero_() + self.step_count.zero_() + self.history.clear() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +# Tests for overrides functionality +@dataclass +class MockStepWithNonSerializableParam: + """Mock step that requires a non-serializable parameter.""" + + def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): + self.name = name + # Add type validation for multiplier + if isinstance(multiplier, str): + raise ValueError(f"multiplier must be a number, got string '{multiplier}'") + if not isinstance(multiplier, (int, float)): + raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") + self.multiplier = float(multiplier) + self.env = env # Non-serializable parameter (like gym.Env) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + # Use the env parameter if provided + if self.env is not None: + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_env_info"] = str(self.env) + + # Apply multiplier to reward + new_transition = transition.copy() + if reward is not None: + new_transition[TransitionKey.REWARD] = reward * self.multiplier + + if comp_data: + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + + return new_transition + + def get_config(self) -> dict[str, Any]: + # Note: env is intentionally NOT included here as it's not serializable + return { + "name": self.name, + "multiplier": self.multiplier, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def 
load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@ProcessorStepRegistry.register("registered_mock_step") +@dataclass +class RegisteredMockStep: + """Mock step registered in the registry.""" + + value: int = 42 + device: str = "cpu" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["registered_step_value"] = self.value + comp_data["registered_step_device"] = self.device + + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "value": self.value, + "device": self.device, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockEnvironment: + """Mock environment for testing non-serializable parameters.""" + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return f"MockEnvironment({self.name})" + + +def test_from_pretrained_with_overrides(): + """Test loading processor with parameter overrides.""" + # Create a processor with steps that need overrides + env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) + registered_step = RegisteredMockStep(value=100, device="cpu") + + pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save the pipeline + 
pipeline.save_pretrained(tmp_dir) + + # Create a mock environment for override + mock_env = MockEnvironment("test_env") + + # Load with overrides + overrides = { + "MockStepWithNonSerializableParam": { + "env": mock_env, + "multiplier": 3.0, # Override the multiplier too + }, + "registered_mock_step": {"device": "cuda", "value": 200}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Verify the pipeline was loaded correctly + assert len(loaded_pipeline) == 2 + assert loaded_pipeline.name == "TestOverrides" + + # Test the loaded steps + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # Check that overrides were applied + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert "env_step_env_info" in comp_data + assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" + assert comp_data["registered_step_value"] == 200 + assert comp_data["registered_step_device"] == "cuda" + + # Check that multiplier override was applied + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier) + + +def test_from_pretrained_with_partial_overrides(): + """Test loading processor with overrides for only some steps.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one step + overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}} + + # The current implementation applies overrides to ALL steps with the same class name + # Both steps will get the override + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # The reward should be affected by both steps, both getting the override + # First step: 
1.0 * 5.0 = 5.0 (overridden) + # Second step: 5.0 * 5.0 = 25.0 (also overridden) + assert result[TransitionKey.REWARD] == 25.0 + + +def test_from_pretrained_invalid_override_key(): + """Test that invalid override keys raise KeyError.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override a non-existent step + overrides = {"NonExistentStep": {"param": "value"}} + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_multiple_invalid_override_keys(): + """Test that multiple invalid override keys are reported.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override multiple non-existent steps + overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "NonExistentStep1" in error_msg + assert "NonExistentStep2" in error_msg + assert "Available step keys" in error_msg + + +def test_from_pretrained_registered_step_override(): + """Test overriding registered steps using registry names.""" + registered_step = RegisteredMockStep(value=50, device="cpu") + pipeline = RobotProcessor([registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override using registry name + overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test that overrides were applied + transition = create_transition() + result = loaded_pipeline(transition) + + comp_data = 
result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["registered_step_value"] == 999 + assert comp_data["registered_step_device"] == "cuda" + + +def test_from_pretrained_mixed_registered_and_unregistered(): + """Test overriding both registered and unregistered steps.""" + unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) + registered_step = RegisteredMockStep(value=10, device="cpu") + + pipeline = RobotProcessor([unregistered_step, registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + mock_env = MockEnvironment("mixed_test") + + overrides = { + "MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0}, + "registered_mock_step": {"value": 777}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test both steps + transition = create_transition(reward=2.0) + result = loaded_pipeline(transition) + + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" + assert comp_data["registered_step_value"] == 777 + assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0 + + +def test_from_pretrained_no_overrides(): + """Test that from_pretrained works without overrides (backward compatibility).""" + step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert len(loaded_pipeline) == 1 + + # Test that the step works (env will be None) + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 + + +def test_from_pretrained_empty_overrides(): + """Test that from_pretrained works with empty overrides dict.""" + step = 
MockStepWithNonSerializableParam(multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with empty overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + + assert len(loaded_pipeline) == 1 + + # Test that the step works normally + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 2.0 + + +def test_from_pretrained_override_instantiation_error(): + """Test that instantiation errors with overrides are properly reported.""" + step = MockStepWithNonSerializableParam(multiplier=1.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with invalid parameter type + overrides = { + "MockStepWithNonSerializableParam": { + "multiplier": "invalid_type" # Should be float, not string + } + } + + with pytest.raises(ValueError, match="Failed to instantiate processor step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_with_state_and_overrides(): + """Test that overrides work correctly with steps that have tensor state.""" + step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) + pipeline = RobotProcessor([step]) + + # Process some data to create state + for i in range(10): + transition = create_transition(reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with overrides + overrides = { + "MockStepWithTensorState": { + "learning_rate": 0.05, # Override learning rate + "window_size": 3, # Override window size + } + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_step = loaded_pipeline.steps[0] + + # Check that config overrides were applied + assert loaded_step.learning_rate == 0.05 + 
assert loaded_step.window_size == 3 + + # Check that tensor state was preserved + assert loaded_step.running_count.item() == 10 + + # The running_mean should still have the original window_size (5) from saved state + # but the new step will use window_size=3 for future operations + assert loaded_step.running_mean.shape[0] == 5 # From saved state + + +def test_from_pretrained_override_error_messages(): + """Test that error messages for override failures are helpful.""" + step1 = MockStepWithNonSerializableParam(name="step1") + step2 = RegisteredMockStep() + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Test with invalid override key + overrides = {"WrongStepName": {"param": "value"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "WrongStepName" in error_msg + assert "Available step keys" in error_msg + assert "MockStepWithNonSerializableParam" in error_msg + assert "registered_mock_step" in error_msg + + +def test_repr_empty_processor(): + """Test __repr__ with empty processor.""" + pipeline = RobotProcessor() + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" + assert repr_str == expected + + +def test_repr_single_step(): + """Test __repr__ with single step.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_multiple_steps_under_limit(): + """Test __repr__ with 2-3 steps (all shown).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == 
expected + + # Test with 3 steps (boundary case) + step3 = MockStepWithTensorState() + pipeline = RobotProcessor([step1, step2, step3]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" + assert repr_str == expected + + +def test_repr_many_steps_truncated(): + """Test __repr__ with more than 3 steps (truncated with ellipsis).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockModuleStep() + step5 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4, step5]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +def test_repr_with_custom_name(): + """Test __repr__ with custom processor name.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step], name="CustomProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_with_seed(): + """Test __repr__ with seed parameter.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_with_custom_name_and_seed(): + """Test __repr__ with both custom name and seed.""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2], name="MyProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == expected + + +def test_repr_without_seed(): + """Test __repr__ when seed is explicitly None (should not show seed).""" + step = 
MockStep("test_step") + pipeline = RobotProcessor([step], name="TestProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_various_step_types(): + """Test __repr__ with different types of steps to verify class name extraction.""" + step1 = MockStep() + step2 = MockStepWithTensorState() + step3 = MockModuleStep() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +def test_repr_edge_case_long_names(): + """Test __repr__ handles steps with long class names properly.""" + step1 = MockStepWithNonSerializableParam() + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +# Tests for config filename features and multiple processors +def test_save_with_custom_config_filename(): + """Test saving processor with custom config filename.""" + step = MockStep("test") + pipeline = RobotProcessor([step], name="TestProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save with custom filename + pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json") + + # Check file exists + config_path = Path(tmp_dir) / "my_custom_config.json" + assert config_path.exists() + + # Check content + with open(config_path) as f: + config = json.load(f) + assert config["name"] == "TestProcessor" + + # Load with specific filename + loaded = 
RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") + assert loaded.name == "TestProcessor" + + +def test_multiple_processors_same_directory(): + """Test saving multiple processors to the same directory with different config files.""" + # Create different processors + preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") + + postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save both to same directory + preprocessor.save_pretrained(tmp_dir) + postprocessor.save_pretrained(tmp_dir) + + # Check both config files exist + assert (Path(tmp_dir) / "preprocessor.json").exists() + assert (Path(tmp_dir) / "postprocessor.json").exists() + + # Load them back + loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + + assert loaded_pre.name == "preprocessor" + assert loaded_post.name == "postprocessor" + assert len(loaded_pre) == 2 + assert len(loaded_post) == 1 + + +def test_auto_detect_single_config(): + """Test automatic config detection when there's only one JSON file.""" + step = MockStepWithTensorState() + pipeline = RobotProcessor([step], name="SingleConfig") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without specifying config_filename + loaded = RobotProcessor.from_pretrained(tmp_dir) + assert loaded.name == "SingleConfig" + + +def test_error_multiple_configs_no_filename(): + """Test error when multiple configs exist and no filename specified.""" + proc1 = RobotProcessor([MockStep()], name="processor1") + proc2 = RobotProcessor([MockStep()], name="processor2") + + with tempfile.TemporaryDirectory() as tmp_dir: + proc1.save_pretrained(tmp_dir) + proc2.save_pretrained(tmp_dir) + + # Should raise error + with 
pytest.raises(ValueError, match="Multiple .json files found"): + RobotProcessor.from_pretrained(tmp_dir) + + +def test_state_file_naming_with_indices(): + """Test that state files include pipeline name and step indices to avoid conflicts.""" + # Create multiple steps of same type with state + step1 = MockStepWithTensorState(name="norm1", window_size=5) + step2 = MockStepWithTensorState(name="norm2", window_size=10) + step3 = MockModuleStep(input_dim=5) + + pipeline = RobotProcessor([step1, step2, step3]) + + # Process some data to create state + for i in range(5): + transition = create_transition(observation=torch.randn(2, 5), reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check state files have indices + state_files = sorted(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 3 + + # Files should be named with pipeline name prefix and indices + expected_names = [ + "robotprocessor_step_0.safetensors", + "robotprocessor_step_1.safetensors", + "robotprocessor_step_2.safetensors", + ] + actual_names = [f.name for f in state_files] + assert actual_names == expected_names + + +def test_state_file_naming_with_registry(): + """Test state file naming for registered steps includes pipeline name, index and registry name.""" + + # Register a test step + @ProcessorStepRegistry.register("test_stateful_step") + @dataclass + class TestStatefulStep: + value: int = 0 + + def __init__(self, value: int = 0): + self.value = value + self.state_tensor = torch.randn(3, 3) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self): + return {"value": self.value} + + def state_dict(self): + return {"state_tensor": self.state_tensor} + + def load_state_dict(self, state): + self.state_tensor = state["state_tensor"] + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test 
# More comprehensive override tests
def test_override_with_nested_config():
    """Test overrides with nested configuration dictionaries."""

    @ProcessorStepRegistry.register("complex_config_step")
    @dataclass
    class ComplexConfigStep:
        name: str = "complex"
        simple_param: int = 42
        nested_config: dict = None

        def __post_init__(self):
            if self.nested_config is None:
                self.nested_config = {"level1": {"level2": "default"}}

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            # Surface the nested value in complementary data so the test can
            # observe which config the step was instantiated with.
            comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
            comp_data = dict(comp_data)
            comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing")

            new_transition = transition.copy()
            new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
            return new_transition

        def get_config(self):
            return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config}

        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
            # We do not test feature_contract here
            return features

    try:
        step = ComplexConfigStep()
        pipeline = RobotProcessor([step])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            # Load with nested override
            loaded = RobotProcessor.from_pretrained(
                tmp_dir,
                overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
            )

            # Test that override worked
            transition = create_transition()
            result = loaded(transition)
            assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden"
    finally:
        ProcessorStepRegistry.unregister("complex_config_step")


def test_override_preserves_defaults():
    """Test that overrides only affect specified parameters."""
    step = MockStepWithNonSerializableParam(name="test", multiplier=2.0)
    pipeline = RobotProcessor([step])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Override only one parameter
        loaded = RobotProcessor.from_pretrained(
            tmp_dir,
            overrides={
                "MockStepWithNonSerializableParam": {
                    "multiplier": 5.0  # Only override multiplier
                }
            },
        )

        # Check that name was preserved from saved config
        loaded_step = loaded.steps[0]
        assert loaded_step.name == "test"  # Original value
        assert loaded_step.multiplier == 5.0  # Overridden value


def test_override_type_validation():
    """Test that type errors in overrides are caught properly."""
    step = MockStepWithTensorState(learning_rate=0.01)
    pipeline = RobotProcessor([step])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Try to override with wrong type
        overrides = {
            "MockStepWithTensorState": {
                "window_size": "not_an_int"  # Should be int
            }
        }

        with pytest.raises(ValueError, match="Failed to instantiate"):
            RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)


def test_override_with_callables():
    """Test overriding with callable objects."""

    @ProcessorStepRegistry.register("callable_step")
    @dataclass
    class CallableStep:
        name: str = "callable_step"
        transform_fn: Any = None

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            obs = transition.get(TransitionKey.OBSERVATION)
            if obs is not None and self.transform_fn is not None:
                processed_obs = {}
                for k, v in obs.items():
                    processed_obs[k] = self.transform_fn(v)

                new_transition = transition.copy()
                new_transition[TransitionKey.OBSERVATION] = processed_obs
                return new_transition
            return transition

        def get_config(self):
            # transform_fn is deliberately excluded: callables are not JSON-serializable.
            return {"name": self.name}

        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
            # We do not test feature_contract here
            return features

    try:
        step = CallableStep()
        pipeline = RobotProcessor([step])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            # Define a transform function
            def double_values(x):
                if isinstance(x, (int, float)):
                    return x * 2
                elif isinstance(x, torch.Tensor):
                    return x * 2
                return x

            # Load with callable override
            loaded = RobotProcessor.from_pretrained(
                tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
            )

            # Test it works
            transition = create_transition(observation={"value": torch.tensor(5.0)})
            result = loaded(transition)
            assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0
    finally:
        ProcessorStepRegistry.unregister("callable_step")


def test_override_multiple_same_class_warning():
    """Test behavior when multiple steps of same class exist."""
    step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0)
    step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0)
    pipeline = RobotProcessor([step1, step2])

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Override affects all instances of the class
        loaded = RobotProcessor.from_pretrained(
            tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}
        )

        # Both steps get the same override
        assert loaded.steps[0].multiplier == 10.0
        assert loaded.steps[1].multiplier == 10.0

        # But original names are preserved
        assert loaded.steps[0].name == "step1"
        assert loaded.steps[1].name == "step2"


def test_config_filename_special_characters():
    """Test config filenames with special characters are sanitized."""
    # Processor name with special characters
    pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars")

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Check that filename was sanitized
        json_files = list(Path(tmp_dir).glob("*.json"))
        assert len(json_files) == 1

        # Should have replaced special chars with underscores
        expected_name = "my_processor_with_special_chars.json"
        assert json_files[0].name == expected_name


def test_state_file_naming_with_multiple_processors():
    """Test that state files are properly prefixed with pipeline names to avoid conflicts."""
    # Create two processors with state
    step1 = MockStepWithTensorState(name="norm", window_size=5)
    preprocessor = RobotProcessor([step1], name="PreProcessor")

    step2 = MockStepWithTensorState(name="norm", window_size=10)
    postprocessor = RobotProcessor([step2], name="PostProcessor")

    # Process some data to create state
    for i in range(3):
        transition = create_transition(reward=float(i))
        preprocessor(transition)
        postprocessor(transition)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save both processors to the same directory
        preprocessor.save_pretrained(tmp_dir)
        postprocessor.save_pretrained(tmp_dir)

        # Check that all files exist and are distinct
        assert (Path(tmp_dir) / "preprocessor.json").exists()
        assert (Path(tmp_dir) / "postprocessor.json").exists()
        assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists()
        assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists()

        # Load both back and verify they work correctly
        loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json")
        loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json")

        assert loaded_pre.name == "PreProcessor"
        assert loaded_post.name == "PostProcessor"
        assert loaded_pre.steps[0].window_size == 5
        assert loaded_post.steps[0].window_size == 10


def test_override_with_device_strings():
    """Test overriding device parameters with string values."""

    # FIX(review): this class was decorated with @dataclass while also defining
    # an explicit __init__. @dataclass silently keeps a user-defined __init__,
    # so the decorator (and the field default it implied) was dead, misleading
    # code. A plain class with a hand-written __init__ states the intent.
    @ProcessorStepRegistry.register("device_aware_step")
    class DeviceAwareStep:
        def __init__(self, device: str = "cpu"):
            self.device = device
            self.buffer = torch.zeros(10, device=device)

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            return transition

        def get_config(self):
            return {"device": str(self.device)}

        def state_dict(self):
            return {"buffer": self.buffer}

        def load_state_dict(self, state):
            self.buffer = state["buffer"]

        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
            # We do not test feature_contract here
            return features

    try:
        step = DeviceAwareStep(device="cpu")
        pipeline = RobotProcessor([step])

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipeline.save_pretrained(tmp_dir)

            # Override device
            if torch.cuda.is_available():
                loaded = RobotProcessor.from_pretrained(
                    tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}}
                )

                loaded_step = loaded.steps[0]
                assert loaded_step.device == "cuda:0"
                # Note: buffer will still be on CPU from saved state
                # until .to() is called on the processor

    finally:
        ProcessorStepRegistry.unregister("device_aware_step")


def test_from_pretrained_nonexistent_path():
    """Test error handling when loading from non-existent sources."""
    from huggingface_hub.errors import HfHubHTTPError, HFValidationError

    # Test with an invalid repo ID (too many slashes) - caught by HF validation
    with pytest.raises(HFValidationError):
        RobotProcessor.from_pretrained("/path/that/does/not/exist")

    # Test with a non-existent but valid Hub repo format
    with pytest.raises((FileNotFoundError, HfHubHTTPError)):
        RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo")

    # Test with a local directory that exists but has no config files
    with tempfile.TemporaryDirectory() as tmp_dir:
        with pytest.raises(FileNotFoundError, match="No .json configuration files found"):
            RobotProcessor.from_pretrained(tmp_dir)


def test_save_load_with_custom_converter_functions():
    """Test that custom to_transition and to_output functions are NOT saved."""

    def custom_to_transition(batch):
        # Custom conversion logic
        return {
            TransitionKey.OBSERVATION: batch.get("obs"),
            TransitionKey.ACTION: batch.get("act"),
            TransitionKey.REWARD: batch.get("rew", 0.0),
            TransitionKey.DONE: batch.get("done", False),
            TransitionKey.TRUNCATED: batch.get("truncated", False),
            TransitionKey.INFO: {},
            TransitionKey.COMPLEMENTARY_DATA: {},
        }

    def custom_to_output(transition):
        # Custom output format
        return {
            "obs": transition.get(TransitionKey.OBSERVATION),
            "act": transition.get(TransitionKey.ACTION),
            "rew": transition.get(TransitionKey.REWARD),
            "done": transition.get(TransitionKey.DONE),
            "truncated": transition.get(TransitionKey.TRUNCATED),
        }

    # Create processor with custom converters
    pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output)

    with tempfile.TemporaryDirectory() as tmp_dir:
        pipeline.save_pretrained(tmp_dir)

        # Load - should use default converters
        loaded = RobotProcessor.from_pretrained(tmp_dir)

        # Verify it uses default converters by checking with standard batch format
        batch = {
            "observation.image": torch.randn(1, 3, 32, 32),
            "action": torch.randn(1, 7),
            "next.reward": torch.tensor([1.0]),
            "next.done": torch.tensor([False]),
            "next.truncated": torch.tensor([False]),
            "info": {},
        }

        # Should work with standard format (wouldn't work with custom converter)
        result = loaded(batch)
        assert "observation.image" in result  # Standard format preserved


class NonCompliantStep:
    """Intentionally non-compliant: missing feature_contract."""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition


def test_construction_rejects_step_without_feature_contract():
    with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"):
        RobotProcessor([NonCompliantStep()])


class NonCallableStep:
    """Intentionally non-compliant: missing __call__."""

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


def test_construction_rejects_step_without_call():
    with pytest.raises(TypeError, match=r"must define __call__"):
        RobotProcessor([NonCallableStep()])


@dataclass
class FeatureContractAddStep:
    """Adds a PolicyFeature"""

    key: str = "a"
    # NOTE(review): this single default PolicyFeature instance is shared by
    # every FeatureContractAddStep() built with defaults. Harmless here since
    # the tests never mutate it, but field(default_factory=...) would be safer
    # if PolicyFeature is mutable — confirm before relying on it elsewhere.
    value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,))

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features[self.key] = self.value
        return features


@dataclass
class FeatureContractMutateStep:
    """Mutates a PolicyFeature"""

    key: str = "a"
    fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x  # noqa: E731

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features[self.key] = self.fn(features.get(self.key))
        return features


@dataclass
class FeatureContractBadReturnStep:
    """Returns a non-dict"""

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return ["not-a-dict"]


@dataclass
class FeatureContractRemoveStep:
    """Removes a PolicyFeature"""

    key: str

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        return transition

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        features.pop(self.key, None)
        return features


def test_feature_contract_orders_and_merges(policy_feature_factory):
    p = RobotProcessor(
        [
            FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
            FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))),
            FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))),
        ]
    )
    out = p.feature_contract({})

    assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,)
    assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,)
    assert_contract_is_typed(out)


def test_feature_contract_respects_initial_without_mutation(policy_feature_factory):
    initial = {
        "seed": policy_feature_factory(FeatureType.STATE, (7,)),
        "nested": policy_feature_factory(FeatureType.ENV, (0,)),
    }
    p = RobotProcessor(
        [
            FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))),
            FeatureContractMutateStep(
                "nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,))
            ),
        ]
    )
    out = p.feature_contract(initial_features=initial)

    assert out["seed"].shape == (8,)
    assert out["nested"].shape == (5,)
    # Initial dict must be preserved
    assert initial["seed"].shape == (7,)
    assert initial["nested"].shape == (0,)

    assert_contract_is_typed(out)


def test_feature_contract_type_error_on_bad_step():
    p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()])
    with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"):
        _ = p.feature_contract({})


def test_feature_contract_execution_order_tracking():
    class Track:
        def __init__(self, label):
            self.label = label

        def __call__(self, transition: EnvTransition) -> EnvTransition:
            return transition

        def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
            # Each step appends its code to the shape tuple, so the final
            # shape records the exact execution order.
            code = {"A": 1, "B": 2, "C": 3}[self.label]
            pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=()))
            features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,))
            return features

    out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({})
    assert out["order"].shape == (1, 2, 3)


def test_feature_contract_remove_key(policy_feature_factory):
    p = RobotProcessor(
        [
            FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))),
            FeatureContractRemoveStep("a"),
        ]
    )
    out = p.feature_contract({})
    assert "a" not in out


def test_feature_contract_remove_from_initial(policy_feature_factory):
    initial = {
        "keep": policy_feature_factory(FeatureType.STATE, (1,)),
        "drop": policy_feature_factory(FeatureType.STATE, (1,)),
    }
    p = RobotProcessor([FeatureContractRemoveStep("drop")])
    out = p.feature_contract(initial_features=initial)
    assert "drop" not in out and out["keep"] == initial["keep"]
import tempfile
from pathlib import Path

import numpy as np
import torch

from lerobot.configs.types import FeatureType
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
from tests.conftest import assert_contract_is_typed


def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Helper to create an EnvTransition dictionary."""
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
    }


def test_basic_renaming():
    """Test basic key renaming functionality."""
    rename_map = {
        "old_key1": "new_key1",
        "old_key2": "new_key2",
    }
    processor = RenameProcessor(rename_map=rename_map)

    observation = {
        "old_key1": torch.tensor([1.0, 2.0]),
        "old_key2": np.array([3.0, 4.0]),
        "unchanged_key": "keep_me",
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check renamed keys
    assert "new_key1" in processed_obs
    assert "new_key2" in processed_obs
    assert "old_key1" not in processed_obs
    assert "old_key2" not in processed_obs

    # Check values are preserved
    torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
    np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))

    # Check unchanged key is preserved
    assert processed_obs["unchanged_key"] == "keep_me"


def test_empty_rename_map():
    """Test processor with empty rename map (should pass through unchanged)."""
    processor = RenameProcessor(rename_map={})

    observation = {
        "key1": torch.tensor([1.0]),
        "key2": "value2",
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # All keys should be unchanged
    assert processed_obs.keys() == observation.keys()
    torch.testing.assert_close(processed_obs["key1"], observation["key1"])
    assert processed_obs["key2"] == observation["key2"]


def test_none_observation():
    """Test processor with None observation."""
    processor = RenameProcessor(rename_map={"old": "new"})

    transition = create_transition()
    result = processor(transition)

    # Should return transition unchanged
    assert result == transition


def test_overlapping_rename():
    """Test renaming when new names might conflict."""
    rename_map = {
        "a": "b",
        "b": "c",  # This creates a potential conflict
    }
    processor = RenameProcessor(rename_map=rename_map)

    observation = {
        "a": 1,
        "b": 2,
        "x": 3,
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check that renaming happens correctly
    assert "a" not in processed_obs
    assert processed_obs["b"] == 1  # 'a' renamed to 'b'
    assert processed_obs["c"] == 2  # original 'b' renamed to 'c'
    assert processed_obs["x"] == 3


def test_partial_rename():
    """Test renaming only some keys."""
    rename_map = {
        "observation.state": "observation.proprio_state",
        "pixels": "observation.image",
    }
    processor = RenameProcessor(rename_map=rename_map)

    # NOTE: "reward" and "info" live inside the observation dict on purpose —
    # they exercise keys that look like transition fields but are not renamed.
    observation = {
        "observation.state": torch.randn(10),
        "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
        "reward": 1.0,
        "info": {"episode": 1},
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check renamed keys
    assert "observation.proprio_state" in processed_obs
    assert "observation.image" in processed_obs
    assert "observation.state" not in processed_obs
    assert "pixels" not in processed_obs

    # Check unchanged keys
    assert processed_obs["reward"] == 1.0
    assert processed_obs["info"] == {"episode": 1}


def test_get_config():
    """Test configuration serialization."""
    rename_map = {
        "old1": "new1",
        "old2": "new2",
    }
    processor = RenameProcessor(rename_map=rename_map)

    config = processor.get_config()
    assert config == {"rename_map": rename_map}


def test_state_dict():
    """Test state dict (should be empty for RenameProcessor)."""
    processor = RenameProcessor(rename_map={"old": "new"})

    state = processor.state_dict()
    assert state == {}

    # Load state dict should work even with empty dict
    processor.load_state_dict({})


def test_integration_with_robot_processor():
    """Test integration with RobotProcessor pipeline."""
    rename_map = {
        "agent_pos": "observation.state",
        "pixels": "observation.image",
    }
    rename_processor = RenameProcessor(rename_map=rename_map)

    pipeline = RobotProcessor([rename_processor])

    observation = {
        "agent_pos": np.array([1.0, 2.0, 3.0]),
        "pixels": np.zeros((32, 32, 3), dtype=np.uint8),
        "other_data": "preserve_me",
    }
    transition = create_transition(
        observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
    )

    result = pipeline(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check renaming worked through pipeline
    assert "observation.state" in processed_obs
    assert "observation.image" in processed_obs
    assert "agent_pos" not in processed_obs
    assert "pixels" not in processed_obs
    assert processed_obs["other_data"] == "preserve_me"

    # Check other transition elements unchanged
    assert result[TransitionKey.REWARD] == 0.5
    assert result[TransitionKey.DONE] is False


def test_save_and_load_pretrained():
    """Test saving and loading processor with RobotProcessor."""
    rename_map = {
        "old_state": "observation.state",
        "old_image": "observation.image",
    }
    processor = RenameProcessor(rename_map=rename_map)
    pipeline = RobotProcessor([processor], name="TestRenameProcessor")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save pipeline
        pipeline.save_pretrained(tmp_dir)

        # Check files were created
        config_path = Path(tmp_dir) / "testrenameprocessor.json"  # Based on name="TestRenameProcessor"
        assert config_path.exists()

        # No state files should be created for RenameProcessor
        state_files = list(Path(tmp_dir).glob("*.safetensors"))
        assert len(state_files) == 0

        # Load pipeline
        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)

        assert loaded_pipeline.name == "TestRenameProcessor"
        assert len(loaded_pipeline) == 1

        # Check that loaded processor works correctly
        loaded_processor = loaded_pipeline.steps[0]
        assert isinstance(loaded_processor, RenameProcessor)
        assert loaded_processor.rename_map == rename_map

        # Test functionality after loading
        observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
        transition = create_transition(observation=observation)

        result = loaded_pipeline(transition)
        processed_obs = result[TransitionKey.OBSERVATION]

        assert "observation.state" in processed_obs
        assert "observation.image" in processed_obs
        assert processed_obs["observation.state"] == [1, 2, 3]
        assert processed_obs["observation.image"] == "image_data"


def test_registry_functionality():
    """Test that RenameProcessor is properly registered."""
    # Check that it's registered
    assert "rename_processor" in ProcessorStepRegistry.list()

    # Get from registry
    retrieved_class = ProcessorStepRegistry.get("rename_processor")
    assert retrieved_class is RenameProcessor

    # Create instance from registry
    instance = retrieved_class(rename_map={"old": "new"})
    assert isinstance(instance, RenameProcessor)
    assert instance.rename_map == {"old": "new"}


def test_registry_based_save_load():
    """Test save/load using registry name instead of module path."""
    processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
    pipeline = RobotProcessor([processor])

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save and load
        pipeline.save_pretrained(tmp_dir)

        # Verify config uses registry name
        import json

        with open(Path(tmp_dir) / "robotprocessor.json") as f:  # Default name is "RobotProcessor"
            config = json.load(f)

        assert "registry_name" in config["steps"][0]
        assert config["steps"][0]["registry_name"] == "rename_processor"
        assert "class" not in config["steps"][0]  # Should use registry, not module path

        # Load should work
        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
        loaded_processor = loaded_pipeline.steps[0]
        assert isinstance(loaded_processor, RenameProcessor)
        assert loaded_processor.rename_map == {"key1": "renamed_key1"}


def test_chained_rename_processors():
    """Test multiple RenameProcessors in a pipeline."""
    # First processor: rename raw keys to intermediate format
    processor1 = RenameProcessor(
        rename_map={
            "pos": "agent_position",
            "img": "camera_image",
        }
    )

    # Second processor: rename to final format
    processor2 = RenameProcessor(
        rename_map={
            "agent_position": "observation.state",
            "camera_image": "observation.image",
        }
    )

    pipeline = RobotProcessor([processor1, processor2])

    observation = {
        "pos": np.array([1.0, 2.0]),
        "img": "image_data",
        "extra": "keep_me",
    }
    transition = create_transition(observation=observation)

    # Step through to see intermediate results
    # (results[0] is the input; results[i] is the output of step i)
    results = list(pipeline.step_through(transition))

    # After first processor
    assert "agent_position" in results[1][TransitionKey.OBSERVATION]
    assert "camera_image" in results[1][TransitionKey.OBSERVATION]

    # After second processor
    final_obs = results[2][TransitionKey.OBSERVATION]
    assert "observation.state" in final_obs
    assert "observation.image" in final_obs
    assert final_obs["extra"] == "keep_me"

    # Original keys should be gone
    assert "pos" not in final_obs
    assert "img" not in final_obs
    assert "agent_position" not in final_obs
    assert "camera_image" not in final_obs


def test_nested_observation_rename():
    """Test renaming with nested observation structures."""
    rename_map = {
        "observation.images.left": "observation.camera.left_view",
        "observation.images.right": "observation.camera.right_view",
        "observation.proprio": "observation.proprioception",
    }
    processor = RenameProcessor(rename_map=rename_map)

    observation = {
        "observation.images.left": torch.randn(3, 64, 64),
        "observation.images.right": torch.randn(3, 64, 64),
        "observation.proprio": torch.randn(7),
        "observation.gripper": torch.tensor([0.0]),  # Not renamed
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check renames
    assert "observation.camera.left_view" in processed_obs
    assert "observation.camera.right_view" in processed_obs
    assert "observation.proprioception" in processed_obs

    # Check unchanged key
    assert "observation.gripper" in processed_obs

    # Check old keys removed
    assert "observation.images.left" not in processed_obs
    assert "observation.images.right" not in processed_obs
    assert "observation.proprio" not in processed_obs


def test_value_types_preserved():
    """Test that various value types are preserved during renaming."""
    rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
    processor = RenameProcessor(rename_map=rename_map)

    tensor_value = torch.randn(3, 3)
    array_value = np.random.rand(2, 2)

    observation = {
        "old_tensor": tensor_value,
        "old_array": array_value,
        "old_scalar": 42,
        "old_string": "hello",
        "old_dict": {"nested": "value"},
        "old_list": [1, 2, 3],
    }
    transition = create_transition(observation=observation)

    result = processor(transition)
    processed_obs = result[TransitionKey.OBSERVATION]

    # Check that values and types are preserved
    assert torch.equal(processed_obs["new_tensor"], tensor_value)
    assert np.array_equal(processed_obs["new_array"], array_value)
    assert processed_obs["new_scalar"] == 42
    assert processed_obs["old_string"] == "hello"
    assert processed_obs["old_dict"] == {"nested": "value"}
    assert processed_obs["old_list"] == [1, 2, 3]


def test_feature_contract_basic_renaming(policy_feature_factory):
    processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
    features = {
        "a": policy_feature_factory(FeatureType.STATE, (2,)),
        "b": policy_feature_factory(FeatureType.ACTION, (3,)),
        "c": policy_feature_factory(FeatureType.ENV, (1,)),
    }

    out = processor.feature_contract(features.copy())

    # Values preserved and typed
    assert out["x"] == features["a"]
    assert out["y"] == features["b"]
    assert out["c"] == features["c"]

    assert_contract_is_typed(out)
    # Input not mutated
    assert set(features) == {"a", "b", "c"}


def test_feature_contract_overlapping_keys(policy_feature_factory):
    # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
    processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
    features = {
        "a": policy_feature_factory(FeatureType.STATE, (1,)),
        "b": policy_feature_factory(FeatureType.STATE, (2,)),
    }
    out = processor.feature_contract(features)

    assert set(out) == {"b", "c"}
    assert out["b"] == features["a"]  # 'a' renamed to 'b'
    assert out["c"] == features["b"]  # 'b' renamed to 'c'
    assert_contract_is_typed(out)


def test_feature_contract_chained_processors(policy_feature_factory):
    # Chain two rename processors at the contract level
    processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
    processor2 = RenameProcessor(
        rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
    )
    pipeline = RobotProcessor([processor1, processor2])

    spec = {
        "pos": policy_feature_factory(FeatureType.STATE, (7,)),
        "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
        "extra": policy_feature_factory(FeatureType.ENV, (1,)),
    }
    out = pipeline.feature_contract(initial_features=spec)

    assert set(out) == {"observation.state", "observation.image", "extra"}
    assert out["observation.state"] == spec["pos"]
    assert out["observation.image"] == spec["img"]
    assert out["extra"] == spec["extra"]
    assert_contract_is_typed(out)