* Refactor observation preprocessing to use a modular pipeline system
  - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations.
  - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline.
  - Added tests for the new processing components and ensured they match the original functionality.
  - Removed hardcoded logic in favor of a more flexible, composable architecture.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Refactor observation processing and improve modularity
  - Updated `ObservationProcessor` to enhance the modular design for processing observations.
  - Cleaned up imports and improved code readability by removing unnecessary lines and comments.
  - Ensured backward compatibility while integrating new processing components.
  - Added tests to validate the functionality of the updated processing architecture.
* Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability.
* Refactor processing architecture to use RobotProcessor
  - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity.
  - Introduced ProcessorStepRegistry for better management of processing steps.
  - Updated relevant documentation and tests to reflect the new processing structure.
  - Enhanced the save/load functionality to support the new processor design.
  - Added a model card template for RobotProcessor to facilitate sharing and documentation.
* Add RobotProcessor tutorial to documentation
  - Introduced a new tutorial on using RobotProcessor for preprocessing robot data.
  - Added a section in the table of contents for easy navigation to the new tutorial.
  - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Add normalization processor and related components
  - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization.
  - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks.
  - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports.
  - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity.
  - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Enhance processing architecture with new components
  - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
  - Updated `__init__.py` to include `RenameProcessor` in module exports.
  - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
  - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
* chore (docs): add docstring for processor
* fix (test): test factory
* fix(test): policies
* Update tests/processor/test_observation_processor.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
  Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
* chore(test): add suggestion made by copilot regarding numpy test
* fix(test): import issue
* Refactor normalization components and update tests
  - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity.
  - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`.
  - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes.
  - Enhanced handling of missing statistics in normalization processes.
* chore (docstring): Improve docstring for NormalizerProcessor
* feat (device processor): Implement device processor
* chore (batch handling): Enhance processing components with batch conversion utilities
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix(test): linting issue
* chore (output format): improves output format
* chore (type): add typing for multiprocess envs
* feat (overrides): Implement support for loading processors with parameter overrides
  - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter.
  - Enhanced error handling for invalid override keys and instantiation errors.
  - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps.
  - Added comprehensive tests to validate the new functionality and ensure backward compatibility.
* chore(normalization): addressing comments from copilot
* chore(learner): nit comment from copilot
* feat(pipeline): Enhance step_through method to support both tuple and dict inputs
* refactor(pipeline): Simplify observation and padding data handling in batch transitions
* Apply suggestions from code review
  Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
  Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
  - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
  - Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
  - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
* refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling
  - Introduced constants for observation keys to enhance readability.
  - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly.
  - Updated the environment state and agent position assignments to use the new constants, improving maintainability.
* feat(pipeline): Add hook unregistration functionality and enhance documentation
  - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
  - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
  - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
* refactor(pipeline): Clarify hook behavior and improve documentation
  - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
  - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
  - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
  - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
* feat(pipeline): Add __repr__ method to RobotProcessor for improved readability
  - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed.
  - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values.
  - Ensured that the representation handles long lists of steps with truncation for better readability.
* chore(pipeline): Move _CFG_NAME along other class member
* refactor(pipeline): Utilize get_safe_torch_device for device assignment
  - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
  - This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
* refactor(pipeline): Enhance state filename generation and profiling method
  - Updated state filename generation to use the registry name when available, improving clarity in saved files.
  - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling.
  - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results.
* chore(doc): address pip install command for lerobot that does not exist yet
* feat(pipeline): Enhance configuration filename handling and state file naming
  - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default.
  - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved.
  - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files.
  - Updated tests to validate new features, including custom filenames and automatic config detection.
* refactor(pipeline): Improve state file naming conventions for clarity and uniqueness
  - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory.
  - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts.
  - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files.
* docs(pipeline): Add clarification for repo name sanitization process
* Feat/pipeline add feature contract (#1637)
  * Add feature contract to pipelinestep and pipeline
  * Add tests
  * Add processor tests
  * PR feedback
  * incorporate pr feedback
  * type in doc
  * oops
* docs(pipeline): Clarify transition handling and hook behavior
  - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
  - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
  - Enhanced test assertions to verify the structure of results and the correctness of processing steps.
* refactor(pipeline): Remove to() method for device management
  - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
  - Removed associated unit tests that validated the functionality of the to() method across various scenarios.
  - Streamlined the pipeline code by focusing on other device management strategies.
* refactor(pipeline): Remove model card generation and streamline processor methods
  - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template.
  - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters.
  - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.
* refactor(observation): Streamline observation preprocessing and remove unused processor methods
  - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting.
  - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow.
  - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.
* refactor(pipeline): Rename parameters for clarity and enhance save/load functionality
  - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path.
  - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names.
  - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability.
* refactor(pipeline): minor improvements (#1684)
  * chore(pipeline): remove unused features + device torch + envtransition keys
  * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor
  * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code
  * test(pipeline): fix broken test after refactors
  * docs(pipeline): update docstrings VanillaObservationProcessor
  * chore(pipeline): move None check to base pipeline classes

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
283 lines
11 KiB
Python
import torch

from lerobot.processor.pipeline import (
    RobotProcessor,
    TransitionKey,
    _default_batch_to_transition,
    _default_transition_to_batch,
)


def _dummy_batch():
    """Create a dummy batch using the new format with observation.* and next.* keys."""
    return {
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.image.right": torch.randn(1, 3, 128, 128),
        "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
        "action": torch.tensor([[0.5]]),
        "next.reward": 1.0,
        "next.done": False,
        "next.truncated": False,
        "info": {"key": "value"},
    }


def test_observation_grouping_roundtrip():
    """Test that observation.* keys are properly grouped and ungrouped."""
    proc = RobotProcessor([])
    batch_in = _dummy_batch()
    batch_out = proc(batch_in)

    # Check that all observation.* keys are preserved
    original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
    reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}

    assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())

    # Check tensor values
    assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
    assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
    assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])

    # Check other fields
    assert torch.allclose(batch_out["action"], batch_in["action"])
    assert batch_out["next.reward"] == batch_in["next.reward"]
    assert batch_out["next.done"] == batch_in["next.done"]
    assert batch_out["next.truncated"] == batch_in["next.truncated"]
    assert batch_out["info"] == batch_in["info"]


def test_batch_to_transition_observation_grouping():
    """Test that _default_batch_to_transition correctly groups observation.* keys."""
    batch = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
        "action": "action_data",
        "next.reward": 1.5,
        "next.done": True,
        "next.truncated": False,
        "info": {"episode": 42},
    }

    transition = _default_batch_to_transition(batch)

    # Check observation is a dict with all observation.* keys
    assert isinstance(transition[TransitionKey.OBSERVATION], dict)
    assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
    assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
    assert "observation.state" in transition[TransitionKey.OBSERVATION]

    # Check values are preserved
    assert torch.allclose(
        transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
    )
    assert torch.allclose(
        transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
    )
    assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]

    # Check other fields
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 1.5
    assert transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"episode": 42}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}


def test_transition_to_batch_observation_flattening():
    """Test that _default_transition_to_batch correctly flattens observation dict."""
    observation_dict = {
        "observation.image.top": torch.randn(1, 3, 128, 128),
        "observation.image.left": torch.randn(1, 3, 128, 128),
        "observation.state": [1, 2, 3, 4],
    }

    transition = {
        TransitionKey.OBSERVATION: observation_dict,
        TransitionKey.ACTION: "action_data",
        TransitionKey.REWARD: 1.5,
        TransitionKey.DONE: True,
        TransitionKey.TRUNCATED: False,
        TransitionKey.INFO: {"episode": 42},
        TransitionKey.COMPLEMENTARY_DATA: {},
    }

    batch = _default_transition_to_batch(transition)

    # Check that observation.* keys are flattened back to batch
    assert "observation.image.top" in batch
    assert "observation.image.left" in batch
    assert "observation.state" in batch

    # Check values are preserved
    assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
    assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
    assert batch["observation.state"] == [1, 2, 3, 4]

    # Check other fields are mapped to next.* format
    assert batch["action"] == "action_data"
    assert batch["next.reward"] == 1.5
    assert batch["next.done"]
    assert not batch["next.truncated"]
    assert batch["info"] == {"episode": 42}


def test_no_observation_keys():
    """Test behavior when there are no observation.* keys."""
    batch = {
        "action": "action_data",
        "next.reward": 2.0,
        "next.done": False,
        "next.truncated": True,
        "info": {"test": "no_obs"},
    }

    transition = _default_batch_to_transition(batch)

    # Observation should be None when no observation.* keys
    assert transition[TransitionKey.OBSERVATION] is None

    # Check other fields
    assert transition[TransitionKey.ACTION] == "action_data"
    assert transition[TransitionKey.REWARD] == 2.0
    assert not transition[TransitionKey.DONE]
    assert transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {"test": "no_obs"}

    # Round trip should work
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["action"] == "action_data"
    assert reconstructed_batch["next.reward"] == 2.0
    assert not reconstructed_batch["next.done"]
    assert reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {"test": "no_obs"}


def test_minimal_batch():
    """Test with minimal batch containing only observation.* and action."""
    batch = {"observation.state": "minimal_state", "action": "minimal_action"}

    transition = _default_batch_to_transition(batch)

    # Check observation
    assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
    assert transition[TransitionKey.ACTION] == "minimal_action"

    # Check defaults
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    # Round trip
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["observation.state"] == "minimal_state"
    assert reconstructed_batch["action"] == "minimal_action"
    assert reconstructed_batch["next.reward"] == 0.0
    assert not reconstructed_batch["next.done"]
    assert not reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {}


def test_empty_batch():
    """Test behavior with empty batch."""
    batch = {}

    transition = _default_batch_to_transition(batch)

    # All fields should have defaults
    assert transition[TransitionKey.OBSERVATION] is None
    assert transition[TransitionKey.ACTION] is None
    assert transition[TransitionKey.REWARD] == 0.0
    assert not transition[TransitionKey.DONE]
    assert not transition[TransitionKey.TRUNCATED]
    assert transition[TransitionKey.INFO] == {}
    assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

    # Round trip
    reconstructed_batch = _default_transition_to_batch(transition)
    assert reconstructed_batch["action"] is None
    assert reconstructed_batch["next.reward"] == 0.0
    assert not reconstructed_batch["next.done"]
    assert not reconstructed_batch["next.truncated"]
    assert reconstructed_batch["info"] == {}


def test_complex_nested_observation():
    """Test with complex nested observation data."""
    batch = {
        "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
        "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
        "observation.state": torch.randn(7),
        "action": torch.randn(8),
        "next.reward": 3.14,
        "next.done": False,
        "next.truncated": True,
        "info": {"episode_length": 200, "success": True},
    }

    transition = _default_batch_to_transition(batch)
    reconstructed_batch = _default_transition_to_batch(transition)

    # Check that all observation keys are preserved
    original_obs_keys = {k for k in batch if k.startswith("observation.")}
    reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}

    assert original_obs_keys == reconstructed_obs_keys

    # Check tensor values
    assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])

    # Check nested dict with tensors
    assert torch.allclose(
        batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
    )
    assert torch.allclose(
        batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
    )

    # Check action tensor
    assert torch.allclose(batch["action"], reconstructed_batch["action"])

    # Check other fields
    assert batch["next.reward"] == reconstructed_batch["next.reward"]
    assert batch["next.done"] == reconstructed_batch["next.done"]
    assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
    assert batch["info"] == reconstructed_batch["info"]


def test_custom_converter():
    """Test that custom converters can still be used."""

    def to_tr(batch):
        # Custom converter that modifies the reward
        tr = _default_batch_to_transition(batch)
        # Double the reward
        reward = tr.get(TransitionKey.REWARD, 0.0)
        new_tr = tr.copy()
        new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
        return new_tr

    def to_batch(tr):
        batch = _default_transition_to_batch(tr)
        return batch

    processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)

    batch = {
        "observation.state": torch.randn(1, 4),
        "action": torch.randn(1, 2),
        "next.reward": 1.0,
        "next.done": False,
    }

    result = processor(batch)

    # Check the reward was doubled by our custom converter
    assert result["next.reward"] == 2.0
    assert torch.allclose(result["observation.state"], batch["observation.state"])
    assert torch.allclose(result["action"], batch["action"])
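The tests above pin down the behavior of the two default converters: flat `observation.*` keys are grouped into one observation dict (or `None` when there are none), missing fields get defaults, and the reverse conversion flattens the observation back and restores the `next.*` keys. The following is a minimal standalone sketch of that behavior, assuming nothing beyond what the assertions check. It is not the `lerobot` implementation: `SketchTransitionKey`, `sketch_batch_to_transition`, and `sketch_transition_to_batch` are hypothetical names, and the sketch always leaves complementary data empty, which may differ from the library.

```python
from enum import Enum
from typing import Any, Dict


class SketchTransitionKey(str, Enum):
    """Hypothetical stand-in for the library's TransitionKey."""

    OBSERVATION = "observation"
    ACTION = "action"
    REWARD = "reward"
    DONE = "done"
    TRUNCATED = "truncated"
    INFO = "info"
    COMPLEMENTARY_DATA = "complementary_data"


def sketch_batch_to_transition(batch: Dict[str, Any]) -> Dict[SketchTransitionKey, Any]:
    """Group flat observation.* keys and fill in defaults for missing fields."""
    observation = {k: v for k, v in batch.items() if k.startswith("observation.")}
    return {
        # An empty grouping becomes None, matching test_no_observation_keys.
        SketchTransitionKey.OBSERVATION: observation or None,
        SketchTransitionKey.ACTION: batch.get("action"),
        SketchTransitionKey.REWARD: batch.get("next.reward", 0.0),
        SketchTransitionKey.DONE: batch.get("next.done", False),
        SketchTransitionKey.TRUNCATED: batch.get("next.truncated", False),
        SketchTransitionKey.INFO: batch.get("info", {}),
        # Assumption: the sketch never populates complementary data.
        SketchTransitionKey.COMPLEMENTARY_DATA: {},
    }


def sketch_transition_to_batch(transition: Dict[SketchTransitionKey, Any]) -> Dict[str, Any]:
    """Flatten the observation dict back into a batch with next.* keys."""
    batch = {
        "action": transition.get(SketchTransitionKey.ACTION),
        "next.reward": transition.get(SketchTransitionKey.REWARD, 0.0),
        "next.done": transition.get(SketchTransitionKey.DONE, False),
        "next.truncated": transition.get(SketchTransitionKey.TRUNCATED, False),
        "info": transition.get(SketchTransitionKey.INFO, {}),
    }
    observation = transition.get(SketchTransitionKey.OBSERVATION)
    if isinstance(observation, dict):
        # Observation keys already carry the "observation." prefix.
        batch.update(observation)
    return batch


if __name__ == "__main__":
    flat = {"observation.state": [1, 2, 3, 4], "action": "action_data", "next.reward": 1.5}
    grouped = sketch_batch_to_transition(flat)
    print(grouped[SketchTransitionKey.OBSERVATION])
    print(sketch_transition_to_batch(grouped))
```

Because the observation keys keep their `observation.` prefix inside the grouped dict, the round trip needs no key rewriting, which is why the tests can compare the key sets directly.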