lerobot/tests/processor/test_batch_conversion.py
Adil Zouitine 88f7bf01c1 feat(pipeline): universal processor for LeRobot (#1431)
* Refactor observation preprocessing to use a modular pipeline system

- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
- Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
- Added tests for the new processing components and ensured they match the original functionality.
- Removed hardcoded logic in favor of a more flexible, composable architecture.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor observation processing and improve modularity

- Updated `ObservationProcessor` to enhance the modular design for processing observations.
- Cleaned up imports and improved code readability by removing unnecessary lines and comments.
- Ensured backward compatibility while integrating new processing components.
- Added tests to validate the functionality of the updated processing architecture.

* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.

* Refactor processing architecture to use RobotProcessor

- Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
- Introduced ProcessorStepRegistry for better management of processing steps.
- Updated relevant documentation and tests to reflect the new processing structure.
- Enhanced the save/load functionality to support the new processor design.
- Added a model card template for RobotProcessor to facilitate sharing and documentation.

* Add RobotProcessor tutorial to documentation

- Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
- Added a section in the table of contents for easy navigation to the new tutorial.
- The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add normalization processor and related components

- Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
- Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
- Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
- Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
- Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Enhance processing architecture with new components

- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.

* chore (docs): add docstring for processor

* fix (test): test factory

* fix(test): policies

* Update tests/processor/test_observation_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* chore(test): add suggestion made by copilot regarding numpy test

* fix(test): import issue

* Refactor normalization components and update tests

- Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
- Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
- Enhanced handling of missing statistics in normalization processes.

* chore (docstring): Improve docstring for NormalizerProcessor

* feat (device processor): Implement device processor

* chore (batch handling): Enhance processing components with batch conversion utilities

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix(test): linting issue

* chore (output format): improves output format

* chore (type): add typing for multiprocess envs

* feat (overrides): Implement support for loading processors with parameter overrides

- Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
- Enhanced error handling for invalid override keys and instantiation errors.
- Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
- Added comprehensive tests to validate the new functionality and ensure backward compatibility.
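
To make the override idea concrete, here is a minimal sketch of rebuilding steps from a saved config while injecting objects that cannot be serialized. It is not LeRobot's loader: `REGISTRY`, `load_steps`, the toy step classes, the `{"registry_name": ..., "config": ...}` entry layout and keying overrides by registry name are all hypothetical choices made for illustration.

```python
from __future__ import annotations

from typing import Any


class ScaleStep:
    """Toy step whose only parameter is plain, serializable config."""

    def __init__(self, factor: float = 1.0):
        self.factor = factor


class EnvStep:
    """Toy step that needs a live, non-serializable object (e.g. an env handle)."""

    def __init__(self, env: Any = None):
        self.env = env


# Hypothetical registry mapping saved names to step classes.
REGISTRY: dict[str, type] = {"scale": ScaleStep, "env_step": EnvStep}


def load_steps(
    saved: list[dict[str, Any]],
    overrides: dict[str, dict[str, Any]] | None = None,
) -> list[Any]:
    """Rebuild steps from saved entries, merging per-step overrides over the saved config."""
    overrides = overrides or {}
    known = {entry["registry_name"] for entry in saved}
    unknown = set(overrides) - known
    if unknown:
        # Fail loudly rather than silently ignoring a misspelled step name.
        raise KeyError(f"Override keys {sorted(unknown)} match no saved step (known: {sorted(known)})")

    steps = []
    for entry in saved:
        name = entry["registry_name"]
        kwargs = {**entry.get("config", {}), **overrides.get(name, {})}
        try:
            steps.append(REGISTRY[name](**kwargs))
        except TypeError as exc:
            # Wrong or missing constructor arguments surface as a clearer error.
            raise ValueError(f"Could not instantiate step '{name}' with {kwargs}: {exc}") from exc
    return steps
```

Merging overrides on top of the saved kwargs keeps the saved file JSON-friendly while live objects (an environment handle, a model) are supplied at load time; rejecting unknown override keys catches typos that would otherwise be ignored.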

* chore(normalization): addressing comments from copilot

* chore(learner): nit comment from copilot

* feat(pipeline): Enhance step_through method to support both tuple and dict inputs

* refactor(pipeline): Simplify observation and padding data handling in batch transitions

* Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.

* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling

- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.

* feat(pipeline): Add hook unregistration functionality and enhance documentation

- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.

* refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
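
The hook semantics described in the two entries above (register/unregister, an error for unknown hooks, return values ignored, and no hooks during `step_through`) can be sketched with a toy class. `TinyProcessor`, its method names and the `hook(step_index, transition)` signature are assumptions for illustration, not the actual `RobotProcessor` API.

```python
from __future__ import annotations

from typing import Any, Callable, Iterator

Transition = dict[str, Any]
Step = Callable[[Transition], Transition]
Hook = Callable[[int, Transition], Any]


class TinyProcessor:
    """Hypothetical processor with observe-only hooks (not LeRobot's RobotProcessor)."""

    def __init__(self, steps: list[Step]):
        self.steps = list(steps)
        self.before_step_hooks: list[Hook] = []
        self.after_step_hooks: list[Hook] = []

    def register_before_step_hook(self, fn: Hook) -> None:
        self.before_step_hooks.append(fn)

    def register_after_step_hook(self, fn: Hook) -> None:
        self.after_step_hooks.append(fn)

    def unregister_before_step_hook(self, fn: Hook) -> None:
        self._unregister(self.before_step_hooks, fn, "before-step")

    def unregister_after_step_hook(self, fn: Hook) -> None:
        self._unregister(self.after_step_hooks, fn, "after-step")

    @staticmethod
    def _unregister(hooks: list[Hook], fn: Hook, kind: str) -> None:
        if fn not in hooks:
            raise ValueError(f"{fn!r} is not a registered {kind} hook")
        hooks.remove(fn)

    def __call__(self, transition: Transition) -> Transition:
        for idx, step in enumerate(self.steps):
            for hook in self.before_step_hooks:
                hook(idx, transition)  # return value ignored: hooks only observe
            transition = step(transition)
            for hook in self.after_step_hooks:
                hook(idx, transition)
        return transition

    def step_through(self, transition: Transition) -> Iterator[Transition]:
        """Yield each step's output in turn; hooks are deliberately not run here."""
        for step in self.steps:
            transition = step(transition)
            yield transition
```

Ignoring hook return values means a logging or debugging hook cannot accidentally change what the pipeline computes.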

* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability

- Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
- Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
- Ensured that the representation handles long lists of steps with truncation for better readability.
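
A truncating `__repr__` of the kind described above might look like the following sketch. `ReprDemo`, the `MAX_SHOWN` threshold and the exact output format are hypothetical; the real `RobotProcessor.__repr__` may format things differently.

```python
from __future__ import annotations


class ReprDemo:
    """Hypothetical processor showing a truncating __repr__ (format is an assumption)."""

    MAX_SHOWN = 4  # longer step lists are abbreviated

    def __init__(self, steps: list[object], name: str | None = None, seed: int | None = None):
        self.steps = steps
        self.name = name
        self.seed = seed

    def __repr__(self) -> str:
        names = [type(s).__name__ for s in self.steps]
        if len(names) > self.MAX_SHOWN:
            # Keep the first two and the last step, summarizing the rest.
            hidden = len(names) - 3
            names = names[:2] + [f"... ({hidden} more)"] + names[-1:]
        parts = [f"steps=[{', '.join(names)}]"]
        if self.name is not None:
            parts.insert(0, f"name={self.name!r}")
        if self.seed is not None:
            parts.append(f"seed={self.seed}")
        return f"{type(self).__name__}({', '.join(parts)})"
```

For six steps named `A` to `F` and `seed=0`, this prints `ReprDemo(steps=[A, B, ... (3 more), F], seed=0)`.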

* chore(pipeline): Move _CFG_NAME along other class member

* refactor(pipeline): Utilize get_safe_torch_device for device assignment

- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.

* refactor(pipeline): Enhance state filename generation and profiling method

- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
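
The profiling changes (untimed warm-up runs plus a fresh deep copy of the input for every timed run) can be sketched as a standalone helper. This `profile_steps` is a hypothetical function over plain callables; its name echoes the method mentioned above, but its signature, the `step_{i}_{name}` result keys and the default run counts are assumptions.

```python
from __future__ import annotations

import copy
import time
from typing import Any, Callable


def profile_steps(
    steps: list[Callable[[Any], Any]],
    transition: Any,
    num_runs: int = 100,
    warmup_runs: int = 10,
) -> dict[str, float]:
    """Return mean wall-clock seconds per step (hypothetical helper, not LeRobot's API)."""
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")
    timings: dict[str, float] = {}
    current = transition
    for idx, step in enumerate(steps):
        key = f"step_{idx}_{type(step).__name__}"
        # Warm-up runs absorb one-off costs (lazy init, caches) and are not timed.
        for _ in range(warmup_runs):
            step(copy.deepcopy(current))
        total = 0.0
        for _ in range(num_runs):
            sample = copy.deepcopy(current)  # identical, unmutated input for every run
            start = time.perf_counter()
            step(sample)
            total += time.perf_counter() - start
        timings[key] = total / num_runs
        # Feed this step's real output to the next step, as the pipeline would.
        current = step(copy.deepcopy(current))
    return timings
```

Copying outside the timed region keeps the copy cost out of the measurement, while still protecting each run from mutations made by the previous one.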

* chore(doc): address pip install command for lerobot that does not exist yet

* feat(pipeline): Enhance configuration filename handling and state file naming

- Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
- Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
- Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
- Updated tests to validate new features, including custom filenames and automatic config detection.
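
Automatic config detection with an error for ambiguous directories could be sketched as below. The helper name `find_config`, its `config_filename` parameter and the assumption that processor configs are the only `*.json` files in the directory are illustrative, not LeRobot's actual rules.

```python
from __future__ import annotations

from pathlib import Path


def find_config(directory: str | Path, config_filename: str | None = None) -> Path:
    """Locate a processor config, auto-detecting it when no filename is given."""
    directory = Path(directory)
    if config_filename is not None:
        # An explicit (custom) filename always wins over auto-detection.
        path = directory / config_filename
        if not path.is_file():
            raise FileNotFoundError(path)
        return path
    candidates = sorted(directory.glob("*.json"))
    if not candidates:
        raise FileNotFoundError(f"No *.json config found in {directory}")
    if len(candidates) > 1:
        # Refuse to guess when several processors share a directory.
        names = ", ".join(p.name for p in candidates)
        raise ValueError(f"Multiple configs in {directory} ({names}); pass config_filename explicitly")
    return candidates[0]
```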

* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness

- Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
- Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
- Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
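
One way to build state filenames that stay unique across step positions and across processors sharing a directory is sketched below. The `sanitize` rules, the `<processor>_step_<index>_<step>.safetensors` pattern and the `.safetensors` extension are assumptions for illustration, not the format LeRobot writes.

```python
import re


def sanitize(name: str) -> str:
    """Make a processor name filesystem-safe (these rules are an assumption)."""
    return re.sub(r"[^A-Za-z0-9_]+", "_", name).strip("_").lower()


def state_filename(processor_name: str, step_index: int, step_name: str) -> str:
    """Hypothetical scheme: <processor>_step_<index>_<step>.safetensors."""
    return f"{sanitize(processor_name)}_step_{step_index}_{step_name}.safetensors"
```

With this scheme, `state_filename("My Pre-Processor", 0, "normalizer")` yields `my_pre_processor_step_0_normalizer.safetensors`; the step index separates two normalizers in one processor, and the processor name separates a preprocessor from a postprocessor saved side by side.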

* docs(pipeline): Add clarification for repo name sanitization process

* Feat/pipeline add feature contract (#1637)

* Add feature contract to pipelinestep and pipeline

* Add tests

* Add processor tests

* PR feedback

* incorporate PR feedback

* typo in doc

* oops

* docs(pipeline): Clarify transition handling and hook behavior

- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
- Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
- Enhanced test assertions to verify the structure of results and the correctness of processing steps.

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.

* refactor(pipeline): Remove model card generation and streamline processor methods

- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template.
- Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters.
- Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.

* refactor(observation): Streamline observation preprocessing and remove unused processor methods

- Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting.
- Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow.
- Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.

* refactor(pipeline): Rename parameters for clarity and enhance save/load functionality

- Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
- Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names.
- Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability.

* refactor(pipeline): minor improvements (#1684)

* chore(pipeline): remove unused features + device torch + envtransition keys

* refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor

* refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code

* test(pipeline): fix broken test after refactors

* docs(pipeline): update docstrings VanillaObservationProcessor

* chore(pipeline): move None check to base pipeline classes

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-08-06 16:11:04 +02:00


import torch

from lerobot.processor.pipeline import (
    RobotProcessor,
    TransitionKey,
    _default_batch_to_transition,
    _default_transition_to_batch,
)


def _dummy_batch():
    """Create a dummy batch using the new format with observation.* and next.* keys."""
    return {
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.image.right": torch.randn(1, 3, 128, 128),
        "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
        "action": torch.tensor([[0.5]]),
        "next.reward": 1.0,
        "next.done": False,
        "next.truncated": False,
        "info": {"key": "value"},
    }


def test_observation_grouping_roundtrip():
    """Test that observation.* keys are properly grouped and ungrouped."""
    proc = RobotProcessor([])
    batch_in = _dummy_batch()
    batch_out = proc(batch_in)

    # Check that all observation.* keys are preserved
    original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
    reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
    assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())

    # Check tensor values
    assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
    assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
    assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])

    # Check other fields
    assert torch.allclose(batch_out["action"], batch_in["action"])
    assert batch_out["next.reward"] == batch_in["next.reward"]
    assert batch_out["next.done"] == batch_in["next.done"]
    assert batch_out["next.truncated"] == batch_in["next.truncated"]
    assert batch_out["info"] == batch_in["info"]


def test_batch_to_transition_observation_grouping():
    """Test that _default_batch_to_transition correctly groups observation.* keys."""
    batch = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
        "action": "action_data",
        "next.reward": 1.5,
        "next.done": True,
        "next.truncated": False,
        "info": {"episode": 42},
    }

    transition = _default_batch_to_transition(batch)

    # Check observation is a dict with all observation.* keys
    assert isinstance(transition[TransitionKey.OBSERVATION], dict)
    assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
    assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
    assert "observation.state" in transition[TransitionKey.OBSERVATION]

    # Check values are preserved
    assert torch.allclose(
        transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
    )
    assert torch.allclose(
        transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
    )
    assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]

    # Check other fields
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 1.5
    assert transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"episode": 42}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}


def test_transition_to_batch_observation_flattening():
    """Test that _default_transition_to_batch correctly flattens observation dict."""
    observation_dict = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
    }

    transition = {
        TransitionKey.OBSERVATION: observation_dict,
        TransitionKey.ACTION: "action_data",
        TransitionKey.REWARD: 1.5,
        TransitionKey.DONE: True,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: {"episode": 42},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }

    batch = _default_transition_to_batch(transition)

    # Check that observation.* keys are flattened back to batch
    assert "observation.image.top" in batch
    assert "observation.image.left" in batch
    assert "observation.state" in batch

    # Check values are preserved
    assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
    assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
    assert batch["observation.state"] == [1, 2, 3, 4]

    # Check other fields are mapped to next.* format
    assert batch["action"] == "action_data"
    assert batch["next.reward"] == 1.5
    assert batch["next.done"]
    assert not batch["next.truncated"]
    assert batch["info"] == {"episode": 42}


def test_no_observation_keys():
    """Test behavior when there are no observation.* keys."""
    batch = {
        "action": "action_data",
        "next.reward": 2.0,
        "next.done": False,
        "next.truncated": True,
        "info": {"test": "no_obs"},
    }

    transition = _default_batch_to_transition(batch)

    # Observation should be None when no observation.* keys
    assert transition[TransitionKey.OBSERVATION] is None

    # Check other fields
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 2.0
    assert not transition[TransitionKey.DONE]
    assert transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"test": "no_obs"}

    # Round trip should work
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["action"] == "action_data"
    assert reconstructed_batch["next.reward"] == 2.0
    assert not reconstructed_batch["next.done"]
    assert reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {"test": "no_obs"}


def test_minimal_batch():
    """Test with minimal batch containing only observation.* and action."""
    batch = {"observation.state": "minimal_state", "action": "minimal_action"}

    transition = _default_batch_to_transition(batch)

    # Check observation
    assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
    assert transition[TransitionKey.ACTION] == "minimal_action"

    # Check defaults
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    # Round trip
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["observation.state"] == "minimal_state"
    assert reconstructed_batch["action"] == "minimal_action"
    assert reconstructed_batch["next.reward"] == 0.0
    assert not reconstructed_batch["next.done"]
    assert not reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {}


def test_empty_batch():
    """Test behavior with empty batch."""
    batch = {}

    transition = _default_batch_to_transition(batch)

    # All fields should have defaults
    assert transition[TransitionKey.OBSERVATION] is None
    assert transition[TransitionKey.ACTION] is None
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    # Round trip
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["action"] is None
    assert reconstructed_batch["next.reward"] == 0.0
    assert not reconstructed_batch["next.done"]
    assert not reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {}


def test_complex_nested_observation():
    """Test with complex nested observation data."""
    batch = {
        "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
        "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
        "observation.state": torch.randn(7),
        "action": torch.randn(8),
        "next.reward": 3.14,
        "next.done": False,
        "next.truncated": True,
        "info": {"episode_length": 200, "success": True},
    }

    transition = _default_batch_to_transition(batch)
    reconstructed_batch = _default_transition_to_batch(transition)

    # Check that all observation keys are preserved
    original_obs_keys = {k for k in batch if k.startswith("observation.")}
    reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
    assert original_obs_keys == reconstructed_obs_keys

    # Check tensor values
    assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])

    # Check nested dict with tensors
    assert torch.allclose(
        batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
    )
    assert torch.allclose(
        batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
    )

    # Check action tensor
    assert torch.allclose(batch["action"], reconstructed_batch["action"])

    # Check other fields
    assert batch["next.reward"] == reconstructed_batch["next.reward"]
    assert batch["next.done"] == reconstructed_batch["next.done"]
    assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
    assert batch["info"] == reconstructed_batch["info"]


def test_custom_converter():
    """Test that custom converters can still be used."""

    def to_tr(batch):
        # Custom converter that modifies the reward
        tr = _default_batch_to_transition(batch)
        # Double the reward
        reward = tr.get(TransitionKey.REWARD, 0.0)
        new_tr = tr.copy()
        new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
        return new_tr

    def to_batch(tr):
        batch = _default_transition_to_batch(tr)
        return batch

    processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)

    batch = {
        "observation.state": torch.randn(1, 4),
        "action": torch.randn(1, 2),
        "next.reward": 1.0,
        "next.done": False,
    }

    result = processor(batch)

    # Check the reward was doubled by our custom converter
    assert result["next.reward"] == 2.0
    assert torch.allclose(result["observation.state"], batch["observation.state"])
    assert torch.allclose(result["action"], batch["action"])
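
For reference, the contract these tests pin down can be summarized in a torch-free sketch of the two converters: `observation.*` keys are grouped into one dict (or `None` when there are none), `next.*` keys map onto reward/done/truncated with defaults `0.0`/`False`/`False`, and `info` defaults to `{}`. This is not the library's code: `batch_to_transition`, `transition_to_batch` and the string values of this local `TransitionKey` are stand-ins, and any handling of complementary data beyond an empty dict is left out.

```python
from __future__ import annotations

from enum import Enum
from typing import Any


class TransitionKey(str, Enum):
    """Local stand-in; member names follow the tests, string values are assumed."""

    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"


def batch_to_transition(batch: dict[str, Any]) -> dict[TransitionKey, Any]:
    """Group observation.* keys and map next.* keys onto transition slots."""
    observation = {k: v for k, v in batch.items() if k.startswith("observation.")}
    return {
        TransitionKey.OBSERVATION: observation or None,  # None when no observation.* keys
        TransitionKey.ACTION: batch.get("action"),
        TransitionKey.REWARD: batch.get("next.reward", 0.0),
        TransitionKey.DONE: batch.get("next.done", False),
        TransitionKey.TRUNCATED: batch.get("next.truncated", False),
        TransitionKey.INFO: batch.get("info", {}),
        TransitionKey.COMPLEMENTARY_DATA: {},
    }


def transition_to_batch(transition: dict[TransitionKey, Any]) -> dict[str, Any]:
    """Inverse mapping: flatten the observation dict back into observation.* keys."""
    batch = {
        "action": transition.get(TransitionKey.ACTION),
        "next.reward": transition.get(TransitionKey.REWARD, 0.0),
        "next.done": transition.get(TransitionKey.DONE, False),
        "next.truncated": transition.get(TransitionKey.TRUNCATED, False),
        "info": transition.get(TransitionKey.INFO, {}),
    }
    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict):
        batch.update(observation)
    return batch
```

Because the observation dict is merged back into the top level unchanged, its keys and values survive a round trip as-is, which is the property `test_observation_grouping_roundtrip` and `test_complex_nested_observation` check against the real converters.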