feat(pipeline): universal processor for LeRobot (#1431)

* Refactor observation preprocessing to use a modular pipeline system

- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor observation processing and improve modularity

- Updated `ObservationProcessor` to enhance the modular design for processing observations.
- Cleaned up imports and improved code readability by removing unnecessary lines and comments.
- Ensured backward compatibility while integrating new processing components.
- Added tests to validate the functionality of the updated processing architecture.

* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.

* Refactor processing architecture to use RobotProcessor

- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
- Introduced ProcessorStepRegistry for better management of processing steps.
- Updated relevant documentation and tests to reflect the new processing structure.
- Enhanced the save/load functionality to support the new processor design.
- Added a model card template for RobotProcessor to facilitate sharing and documentation.

* Add RobotProcessor tutorial to documentation

- Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
- Added a section in the table of contents for easy navigation to the new tutorial.
- The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add normalization processor and related components

- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enhance processing architecture with new components

- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.

* chore (docs): add docstring for processor

* fix (test): test factory

* fix(test): policies

* Update tests/processor/test_observation_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* chore(test): add suggestion made by copilot regarding numpy test

* fix(test): import issue

* Refactor normalization components and update tests

- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.
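
A minimal usage sketch of the consolidated classes (names match the final diff below; the feature spec and stats values are illustrative, and `FeatureType.ACTION` / `NormalizationMode.MEAN_STD` are assumed enum members):

```python
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import NormalizerProcessor, TransitionKey, UnnormalizerProcessor

# Illustrative feature spec and stats; in practice stats come from dataset.meta.stats.
features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
stats = {"action": {"mean": [0.0, 1.0], "std": [1.0, 2.0]}}

normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

# Normalize then unnormalize an action; the round trip recovers ~[1.0, 3.0] up to eps.
transition = {TransitionKey.ACTION: torch.tensor([1.0, 3.0])}
restored = unnormalizer(normalizer(transition))[TransitionKey.ACTION]
```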

* chore (docstring): Improve docstring for NormalizerProcessor

* feat (device processor): Implement device processor

* chore (batch handling): Enhance processing components with batch conversion utilities

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(test): linting issue

* chore (output format): improve output format

* chore (type): add typing for multiprocess envs

* feat (overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
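
A hedged sketch of the override mechanism; the repo ids are placeholders, and keying `overrides` by the registered step name is an assumption not shown in the diffs:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import RobotProcessor

dataset = LeRobotDataset("user/my-dataset")  # placeholder repo id

# Hypothetical: inject non-serializable dataset statistics into a registered step at load time.
processor = RobotProcessor.from_pretrained(
    "user/my-processor",
    overrides={"normalizer_processor": {"stats": dataset.meta.stats}},
)
```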

* chore(normalization): addressing comments from copilot

* chore(learner): nit comment from copilot

* feat(pipeline): Enhance step_through method to support both tuple and dict inputs

* refactor(pipeline): Simplify observation and padding data handling in batch transitions

* Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
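
A sketch of the new dictionary format (key names taken from the processor code below; additional keys such as info/complementary data exist but are omitted here):

```python
import torch

from lerobot.processor import TransitionKey

# EnvTransition is now a plain dict keyed by TransitionKey members instead of a tuple.
transition = {
    TransitionKey.OBSERVATION: {"observation.state": torch.zeros(1, 2)},
    TransitionKey.ACTION: torch.zeros(1, 2),
    TransitionKey.REWARD: 0.0,
    TransitionKey.DONE: False,
    TransitionKey.TRUNCATED: False,
}
action = transition.get(TransitionKey.ACTION)  # dict access replaces tuple indexing
```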

* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling

- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.

* feat(pipeline): Add hook unregistration functionality and enhance documentation

- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
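
A sketch of the hook API under loudly hypothetical names (the pipeline.py diff is suppressed below, so the constructor and method signatures here are assumptions):

```python
from lerobot.processor import RobotProcessor

processor = RobotProcessor(steps=[])  # hypothetical constructor signature

def log_before(transition):  # hypothetical hook signature
    print("processing transition with keys:", list(transition))

processor.register_before_hook(log_before)    # hypothetical method name
processor.unregister_before_hook(log_before)  # hypothetical method name
```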

* refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.

* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability

- Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
- Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
- Ensured that the representation handles long lists of steps with truncation for better readability.

* chore(pipeline): Move _CFG_NAME alongside the other class members

* refactor(pipeline): Utilize get_safe_torch_device for device assignment

- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.

* refactor(pipeline): Enhance state filename generation and profiling method

- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
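
A hedged sketch; only the `warmup_runs` parameter is documented in this commit message, so the rest of the call and the return shape are assumptions (`processor` and `transition` as in the sketches above):

```python
# Hypothetical: profile each step on a representative transition after warmup runs.
timings = processor.profile_steps(transition, warmup_runs=3)
for step_name, seconds in timings.items():  # assumed dict of per-step timings
    print(f"{step_name}: {seconds * 1e3:.2f} ms")
```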

* chore(doc): address pip install command for lerobot that does not exist yet

* feat(pipeline): Enhance configuration filename handling and state file naming

- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
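
A sketch of the save/load round trip; `save_pretrained`/`from_pretrained` appear later in this log, while the `config_filename` keyword is an assumed spelling:

```python
# Hypothetical keyword name for the custom config filename.
processor.save_pretrained("ckpt/processors", config_filename="preprocessor.json")

# With a single config in the directory it is auto-detected; with several, name it explicitly.
loaded = RobotProcessor.from_pretrained("ckpt/processors", config_filename="preprocessor.json")
```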

* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness

- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.

* docs(pipeline): Add clarification for repo name sanitization process

* Feat/pipeline add feature contract (#1637)

* Add feature contract to pipeline step and pipeline

* Add tests

* Add processor tests

* PR feedback

* incorporate PR feedback

* typo in doc

* oops

* docs(pipeline): Clarify transition handling and hook behavior

- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
- Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
- Enhanced test assertions to verify the structure of results and the correctness of processing steps.
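
A sketch of debugging with `step_through`, which yields EnvTransition objects and skips hooks (both behaviors stated in this log):

```python
# Inspect the intermediate transition after each pipeline step; hooks do not fire here.
for intermediate in processor.step_through(transition):
    print({key: type(value).__name__ for key, value in intermediate.items()})
```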

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.

* refactor(pipeline): Remove model card generation and streamline processor methods

- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template.
- Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters.
- Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.

* refactor(observation): Streamline observation preprocessing and remove unused processor methods

- Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting.
- Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow.
- Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.

* refactor(pipeline): Rename parameters for clarity and enhance save/load functionality

- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names.
- Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability.

* refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Author: Adil Zouitine
Date: 2025-08-06 16:11:04 +02:00
Committed by: GitHub
Parent: 6daa579ce1
Commit: 88f7bf01c1
12 changed files with 5738 additions and 0 deletions

lerobot/processor/__init__.py

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .device_processor import DeviceProcessor
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
from .observation_processor import VanillaObservationProcessor
from .pipeline import (
    ActionProcessor,
    DoneProcessor,
    EnvTransition,
    IdentityProcessor,
    InfoProcessor,
    ObservationProcessor,
    ProcessorStep,
    ProcessorStepRegistry,
    RewardProcessor,
    RobotProcessor,
    TransitionKey,
    TruncatedProcessor,
)
from .rename_processor import RenameProcessor

__all__ = [
    "ActionProcessor",
    "DeviceProcessor",
    "DoneProcessor",
    "EnvTransition",
    "IdentityProcessor",
    "InfoProcessor",
    "NormalizerProcessor",
    "UnnormalizerProcessor",
    "ObservationProcessor",
    "ProcessorStep",
    "ProcessorStepRegistry",
    "RenameProcessor",
    "RewardProcessor",
    "RobotProcessor",
    "TransitionKey",
    "TruncatedProcessor",
    "VanillaObservationProcessor",
]
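
A composition sketch using the exports above; the `RobotProcessor` constructor signature is not shown in these diffs, so the `steps=`/`name=` keywords are assumptions based on the commit log:

```python
from lerobot.processor import DeviceProcessor, RobotProcessor, VanillaObservationProcessor

# Hypothetical constructor keywords (steps, name).
processor = RobotProcessor(
    steps=[VanillaObservationProcessor(), DeviceProcessor(device="cpu")],
    name="eval_preprocessor",
)
```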

lerobot/processor/device_processor.py

@@ -0,0 +1,82 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any

import torch

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.utils.utils import get_safe_torch_device


@dataclass
class DeviceProcessor:
    """Processes transitions by moving tensors to the specified device.

    This processor ensures that all tensors in the transition are moved to the
    specified device (CPU or GPU) before they are returned.
    """

    device: torch.device | str = "cpu"

    def __post_init__(self):
        self.device = get_safe_torch_device(self.device)
        self.non_blocking = "cuda" in str(self.device)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Create a copy of the transition
        new_transition = transition.copy()

        # Process observation tensors
        observation = transition.get(TransitionKey.OBSERVATION)
        if observation is not None:
            new_observation = {
                k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
                for k, v in observation.items()
            }
            new_transition[TransitionKey.OBSERVATION] = new_observation

        # Process action tensor
        action = transition.get(TransitionKey.ACTION)
        if action is not None and isinstance(action, torch.Tensor):
            new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)

        # Process reward tensor
        reward = transition.get(TransitionKey.REWARD)
        if reward is not None and isinstance(reward, torch.Tensor):
            new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)

        # Process done tensor
        done = transition.get(TransitionKey.DONE)
        if done is not None and isinstance(done, torch.Tensor):
            new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)

        # Process truncated tensor
        truncated = transition.get(TransitionKey.TRUNCATED)
        if truncated is not None and isinstance(truncated, torch.Tensor):
            new_transition[TransitionKey.TRUNCATED] = truncated.to(
                self.device, non_blocking=self.non_blocking
            )

        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return configuration for serialization."""
        # Store the device as a string so the config stays JSON-serializable.
        return {"device": str(self.device)}

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
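
A quick sanity-check sketch of `DeviceProcessor` as defined above (CPU-only, so it runs anywhere):

```python
import torch

from lerobot.processor import DeviceProcessor, TransitionKey

processor = DeviceProcessor(device="cpu")
transition = {
    TransitionKey.OBSERVATION: {"observation.state": torch.zeros(1, 2)},
    TransitionKey.ACTION: torch.zeros(1, 2),
}
moved = processor(transition)
assert moved[TransitionKey.ACTION].device.type == "cpu"
```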

lerobot/processor/normalize_processor.py

@@ -0,0 +1,331 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey


def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
    """Convert numpy arrays and other types to torch tensors."""
    tensor_stats: dict[str, dict[str, Tensor]] = {}
    for key, sub in stats.items():
        tensor_stats[key] = {}
        for stat_name, value in sub.items():
            if isinstance(value, np.ndarray):
                tensor_val = torch.from_numpy(value.astype(np.float32))
            elif isinstance(value, torch.Tensor):
                tensor_val = value.to(dtype=torch.float32)
            elif isinstance(value, (int, float, list, tuple)):
                tensor_val = torch.tensor(value, dtype=torch.float32)
            else:
                raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
            tensor_stats[key][stat_name] = tensor_val
    return tensor_stats


@dataclass
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessor:
    """Normalizes observations and actions in a single processor step.

    This processor handles normalization of both observation and action tensors
    using either mean/std normalization or min/max scaling to a [-1, 1] range.

    For each tensor key in the stats dictionary, the processor will:
    - Use mean/std normalization if those statistics are provided: (x - mean) / std
    - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1

    The processor can be configured to normalize only specific keys by setting
    the normalize_keys parameter.
    """

    # Features and normalization map are mandatory to match the design of normalize.py
    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    # Pre-computed statistics coming from dataset.meta.stats for instance.
    stats: dict[str, dict[str, Any]] | None = None
    # Explicit subset of keys to normalize. If ``None``, every key (except
    # "action") found in ``stats`` will be normalized. Using a ``set`` makes
    # membership checks O(1).
    normalize_keys: set[str] | None = None
    eps: float = 1e-8
    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
        *,
        normalize_keys: set[str] | None = None,
        eps: float = 1e-8,
    ) -> NormalizerProcessor:
        """Factory helper that pulls statistics from a :class:`LeRobotDataset`.

        The features and norm_map parameters are mandatory to match the design
        pattern used in normalize.py.
        """
        return cls(
            features=features,
            norm_map=norm_map,
            stats=dataset.meta.stats,
            normalize_keys=normalize_keys,
            eps=eps,
        )

    def __post_init__(self):
        # Handle deserialization from JSON config
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features
        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map
        # Convert statistics once so we avoid repeated numpy→Tensor conversions
        # during runtime.
        self.stats = self.stats or {}
        self._tensor_stats = _convert_stats_to_tensors(self.stats)
        # Ensure *normalize_keys* is a set for fast look-ups and compare by
        # value later when returning the configuration.
        if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
            self.normalize_keys = set(self.normalize_keys)

    def _normalize_obs(self, observation):
        if observation is None:
            return None
        # Decide which keys should be normalized for this call.
        if self.normalize_keys is not None:
            keys_to_norm = self.normalize_keys
        else:
            # Use feature map to skip action keys.
            keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
        processed = dict(observation)
        for key in keys_to_norm:
            if key not in processed or key not in self._tensor_stats:
                continue
            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = (tensor - mean) / (std + self.eps)
            elif "min" in stats and "max" in stats:
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        return processed

    def _normalize_action(self, action):
        if action is None or "action" not in self._tensor_stats:
            return action
        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return (tensor - mean) / (std + self.eps)
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._normalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with normalized values
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        config = {
            "eps": self.eps,
            "features": {
                key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }
        if self.normalize_keys is not None:
            # Serialize as a list for YAML / JSON friendliness
            config["normalize_keys"] = sorted(self.normalize_keys)
        return config

    def state_dict(self) -> dict[str, Tensor]:
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def reset(self):
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features


@dataclass
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessor:
    """Inverse normalization for observations and actions.

    Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
    transform.
    """

    features: dict[str, PolicyFeature]
    norm_map: dict[FeatureType, NormalizationMode]
    stats: dict[str, dict[str, Any]] | None = None
    _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)

    @classmethod
    def from_lerobot_dataset(
        cls,
        dataset: LeRobotDataset,
        features: dict[str, PolicyFeature],
        norm_map: dict[FeatureType, NormalizationMode],
    ) -> UnnormalizerProcessor:
        return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)

    def __post_init__(self):
        # Handle deserialization from JSON config
        if self.features and isinstance(list(self.features.values())[0], dict):
            # Features came from JSON - need to reconstruct PolicyFeature objects
            reconstructed_features = {}
            for key, ft_dict in self.features.items():
                reconstructed_features[key] = PolicyFeature(
                    type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
                )
            self.features = reconstructed_features
        if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
            # norm_map came from JSON - need to reconstruct enum keys and values
            reconstructed_norm_map = {}
            for ft_type_str, norm_mode_str in self.norm_map.items():
                reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
            self.norm_map = reconstructed_norm_map
        self.stats = self.stats or {}
        self._tensor_stats = _convert_stats_to_tensors(self.stats)

    def _unnormalize_obs(self, observation):
        if observation is None:
            return None
        keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
        processed = dict(observation)
        for key in keys:
            if key not in processed or key not in self._tensor_stats:
                continue
            orig_val = processed[key]
            tensor = (
                orig_val.to(dtype=torch.float32)
                if isinstance(orig_val, torch.Tensor)
                else torch.as_tensor(orig_val, dtype=torch.float32)
            )
            stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
            if "mean" in stats and "std" in stats:
                mean, std = stats["mean"], stats["std"]
                processed[key] = tensor * std + mean
            elif "min" in stats and "max" in stats:
                min_val, max_val = stats["min"], stats["max"]
                processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
        return processed

    def _unnormalize_action(self, action):
        if action is None or "action" not in self._tensor_stats:
            return action
        tensor = (
            action.to(dtype=torch.float32)
            if isinstance(action, torch.Tensor)
            else torch.as_tensor(action, dtype=torch.float32)
        )
        stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
        if "mean" in stats and "std" in stats:
            mean, std = stats["mean"], stats["std"]
            return tensor * std + mean
        if "min" in stats and "max" in stats:
            min_val, max_val = stats["min"], stats["max"]
            return (tensor + 1) / 2 * (max_val - min_val) + min_val
        raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
        action = self._unnormalize_action(transition.get(TransitionKey.ACTION))

        # Create a new transition with unnormalized values
        new_transition = transition.copy()
        new_transition[TransitionKey.OBSERVATION] = observation
        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "features": {
                key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
            },
            "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
        }

    def state_dict(self) -> dict[str, Tensor]:
        flat = {}
        for key, sub in self._tensor_stats.items():
            for stat_name, tensor in sub.items():
                flat[f"{key}.{stat_name}"] = tensor
        return flat

    def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
        self._tensor_stats.clear()
        for flat_key, tensor in state.items():
            key, stat_name = flat_key.rsplit(".", 1)
            self._tensor_stats.setdefault(key, {})[stat_name] = tensor

    def reset(self):
        pass

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        return features
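
A sketch of the flat `state_dict` round trip implemented above; `FeatureType.STATE` and `NormalizationMode.MEAN_STD` are assumed enum members and the stats values are illustrative:

```python
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import NormalizerProcessor

features = {"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(2,))}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}

src = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)

# Stats are flattened to "key.stat" tensors, so they travel with the processor state.
dst = NormalizerProcessor(features=features, norm_map=norm_map)
dst.load_state_dict(src.state_dict())
assert "observation.state.mean" in dst.state_dict()
```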

lerobot/processor/observation_processor.py

@@ -0,0 +1,157 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass

import einops
import numpy as np
import torch
from torch import Tensor

from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry


@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
    """
    Processes environment observations into the LeRobot format by handling both images and states.

    Image processing:
    - Converts channel-last (H, W, C) images to channel-first (C, H, W)
    - Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
    - Adds a batch dimension if missing
    - Supports single images and image dictionaries

    State processing:
    - Maps 'environment_state' to observation.environment_state
    - Maps 'agent_pos' to observation.state
    - Converts numpy arrays to tensors
    - Adds a batch dimension if missing
    """

    def _process_single_image(self, img: np.ndarray) -> Tensor:
        """Process a single image array."""
        # Convert to tensor
        img_tensor = torch.from_numpy(img)

        # Add batch dimension if needed
        if img_tensor.ndim == 3:
            img_tensor = img_tensor.unsqueeze(0)

        # Validate image format
        _, h, w, c = img_tensor.shape
        if not (c < h and c < w):
            raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
        if img_tensor.dtype != torch.uint8:
            raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

        # Convert to channel-first format
        img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
        # Convert to float32 and normalize to [0, 1]
        img_tensor = img_tensor.type(torch.float32) / 255.0
        return img_tensor

    def _process_observation(self, observation):
        """Processes both image and state observations."""
        processed_obs = observation.copy()

        if "pixels" in processed_obs:
            pixels = processed_obs.pop("pixels")
            if isinstance(pixels, dict):
                imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
            else:
                imgs = {OBS_IMAGE: pixels}
            for imgkey, img in imgs.items():
                processed_obs[imgkey] = self._process_single_image(img)

        if "environment_state" in processed_obs:
            env_state_np = processed_obs.pop("environment_state")
            env_state = torch.from_numpy(env_state_np).float()
            if env_state.dim() == 1:
                env_state = env_state.unsqueeze(0)
            processed_obs[OBS_ENV_STATE] = env_state

        if "agent_pos" in processed_obs:
            agent_pos_np = processed_obs.pop("agent_pos")
            agent_pos = torch.from_numpy(agent_pos_np).float()
            if agent_pos.dim() == 1:
                agent_pos = agent_pos.unsqueeze(0)
            processed_obs[OBS_STATE] = agent_pos

        return processed_obs

    def observation(self, observation):
        return self._process_observation(observation)

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms feature keys to a standardized contract.

        This method handles several renaming patterns:
        - Exact matches (e.g., 'pixels' -> OBS_IMAGE)
        - Prefixed exact matches (e.g., 'observation.pixels' -> OBS_IMAGE)
        - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1')
        - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1')
        - environment_state -> OBS_ENV_STATE
        - agent_pos -> OBS_STATE
        - observation.environment_state -> OBS_ENV_STATE
        - observation.agent_pos -> OBS_STATE
        """
        exact_pairs = {
            "pixels": OBS_IMAGE,
            "environment_state": OBS_ENV_STATE,
            "agent_pos": OBS_STATE,
        }
        prefix_pairs = {
            "pixels.": f"{OBS_IMAGES}.",
        }

        for key in list(features.keys()):
            matched_prefix = False
            for old_prefix, new_prefix in prefix_pairs.items():
                prefixed_old = f"observation.{old_prefix}"
                if key.startswith(prefixed_old):
                    suffix = key[len(prefixed_old) :]
                    features[f"{new_prefix}{suffix}"] = features.pop(key)
                    matched_prefix = True
                    break
                if key.startswith(old_prefix):
                    suffix = key[len(old_prefix) :]
                    features[f"{new_prefix}{suffix}"] = features.pop(key)
                    matched_prefix = True
                    break
            if matched_prefix:
                continue

            for old, new in exact_pairs.items():
                if key == old or key == f"observation.{old}":
                    if key in features:
                        features[new] = features.pop(key)
                    break

        return features
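
A usage sketch of the processor above on a gym-style observation; the string values of `OBS_IMAGE`/`OBS_STATE` ('observation.image', 'observation.state') are assumed:

```python
import numpy as np

from lerobot.processor import VanillaObservationProcessor

proc = VanillaObservationProcessor()
obs = {
    "pixels": np.zeros((96, 96, 3), dtype=np.uint8),  # channel-last uint8 image
    "agent_pos": np.zeros(2, dtype=np.float32),
}
out = proc.observation(obs)
# The image becomes float32 in [0, 1] with shape (1, 3, 96, 96); the state becomes shape (1, 2).
```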

lerobot/processor/pipeline.py (diff suppressed because it is too large)

lerobot/processor/rename_processor.py

@@ -0,0 +1,51 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any

from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import (
    ObservationProcessor,
    ProcessorStepRegistry,
)


@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessor(ObservationProcessor):
    """Rename processor that renames keys in the observation."""

    rename_map: dict[str, str] = field(default_factory=dict)

    def observation(self, observation):
        processed_obs = {}
        for key, value in observation.items():
            if key in self.rename_map:
                processed_obs[self.rename_map[key]] = value
            else:
                processed_obs[key] = value
        return processed_obs

    def get_config(self) -> dict[str, Any]:
        return {"rename_map": self.rename_map}

    def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms:
        - Each key in the observation that appears in `rename_map` is renamed to its value.
        - Keys not in `rename_map` remain unchanged.
        """
        return {self.rename_map.get(k, k): v for k, v in features.items()}
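
Finally, a small check of `RenameProcessor` as defined above:

```python
import numpy as np

from lerobot.processor import RenameProcessor

rename = RenameProcessor(rename_map={"pixels": "observation.image"})
obs = {"pixels": np.zeros((96, 96, 3), dtype=np.uint8), "agent_pos": np.zeros(2)}
out = rename.observation(obs)
assert set(out) == {"observation.image", "agent_pos"}
```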