From 78b866116fd73f9c0c83f9a7bc8fde8ea4b9b96d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 18 Sep 2025 15:25:26 +0200 Subject: [PATCH] feat(processors): use pipelines across the codebase (#1452) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * docs(pipeline): Clarify transition handling and hook behavior - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps. * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * refactor(pipeline): Remove model card generation and streamline processor methods - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability. * refactor(observation): Streamline observation preprocessing and remove unused processor methods - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script. * refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names. - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. * refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. * fix(train.py) push postprocessor with preprocessor - Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * refactor(device_processor): Update device handling and improve type hints - Changed device attribute type from torch.device to str for better clarity. - Introduced a private _device attribute to store the actual torch.device instance. - Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments. - Refactored device-related assertions in tests to use a consistent approach for device type verification. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test(tokenizer_processor): Add require_package decorator for transformers - Introduced @require_package("transformers") decorator in multiple test functions to ensure the transformers package is available before running tests. - This change enhances test reliability by preventing failures due to missing dependencies. * refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure - Introduced RenameProcessor in the preprocessor to handle renaming features. - Combined input and output features in a single NormalizerProcessor for improved efficiency. - Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor. - Added DeviceProcessor to both preprocessor and postprocessor for better device management. * Integrate pipeline and add phone teleop (#1681) * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * fix(ci): temporary fix on dataset deps version * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * refactor(train): Update memory pinning logic for mps compatibility * feat: initial commit phone teleop * ugly delta control * use quaternion * Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * Add debug + calib * cleanup * Add pipeline * fix int * Add record example * nit * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * cleaned up steps and integrated pipeline with feature_contract * refactor steps and robot to pipeline * cleanup pipeline * cleanup code further * make it run * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * Do some todos and cleanup * change feature_contract to dataset_features * use one method for conversion pipeline output to add_frame dict and use base processors where possible * Add back in and use record_loop * update todo * rename to_dataset_frame * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix reference frame * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * update data visualization * update teleop example * fix record bugs * Add replay * Not code * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * Add eval script * fix `q_curr` in InverseKinematicsEEToJoints to the IK solution * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * Fix eval and android gripper * add some tests * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. * fix(train.py) push postprocessor with preprocessor - Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * Cleanup pr * fix more git diff pr issues * add path as type in save_pretrained * small nit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename test file * fix: make dataset_features/feature_contract is optional * fix tests * Encorperate pr feedback * clean up record.py * add ascii art, fix normal record * remove merge issues * fix merge * remove features * Add feedback PR * fix last 4 tests * remove features check * rename to transform_features * add transform_features * fix lekiwi eval and update eval api example --------- Signed-off-by: Adil Zouitine Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Michel Aractingi * refactor(TokenizerProcessor): improve dependency handling and observation management - Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements. * feat(dependencies): Add scipy as a required dependency - Included `scipy>=1.15.2` in the project dependencies to enhance functionality and support for scientific computing tasks. * feat(policies): convert save_policy_to_safetensors with pipeline * refactor(normalization): remove Normalize and Unnormalize classes - Deleted the Normalize and Unnormalize classes from the normalization module to streamline the codebase. - Updated tests to ensure compatibility with the removal of these classes, focusing on the new NormalizerProcessor and UnnormalizerProcessor implementations. - Enhanced the handling of normalization statistics and improved overall code clarity. * refactor(factory): streamline processor loading by removing unused comments - Removed commented-out code related to loading pretrained processors in the make_processor function. - This change enhances code clarity and maintains focus on the current implementation. * feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion - Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios. - Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions. - Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios. * feat(tests): Add comprehensive tests for various policy processors - Introduced new test files for ACT, Classifier, Diffusion, PI0, SAC, SmolVLA, TDMPC, and VQBeT policy processors. - Each test file includes unit tests to validate functionality, including handling of batch sizes, device management, and data type conversions. - Enhanced test coverage to ensure robustness and reliability of processor implementations across different scenarios. * refactor(train): Remove unnecessary tensor device handling in training loop * Refactor`gym_manipulator.py` using the universal pipeline (#1650) * Migrate gym_manipulator to use the pipeline Added get_teleop_events function to capture relevant events from teleop devices unrelated to actions * Added the capability to record a dataset * Added the replay functionality with the pipeline * Refactored `actor.py` to use the pipeline * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * RL works at this commit - fixed actor.py and bugs in gym_manipulator * change folder structure to reduce the size of gym_manip * Refactored hilserl config * Remove dataset and mode from HilSerlEnvConfig to a GymManipulatorConfig to reduce verbose of configs during training * format docs * removed get_teleop_events from abc * Refactor environment configuration and processing pipeline for GymHIL support. Removed device attribute from HILSerlRobotEnvConfig, added DummyTeleopDevice for simulation, and updated processor creation to accommodate GymHIL environments. * Improved typing for HILRobotEnv config and GymManipulator config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Migrated `gym_manipulator` to use a more modular structure similar to phone teleop * Refactor gripper handling and transition processing in HIL and robot kinematic processors - Updated gripper position handling to use a consistent key format across processors - Improved the EEReferenceAndDelta class to handle reference joint positions. - Added support for discrete gripper actions in the GripperVelocityToJoint processor. - Refactored the gym manipulator to improve modularity and clarity in processing steps. * Added delta_action_processor mapping wrapper * Added missing file delta_action_processor and improved imports in `gym_manipulator` * nit * Added missing file joint_observation_processor * Enhance processing architecture with new teleoperation processors - Introduced `AddTeleopActionAsComplimentaryData` and `AddTeleopEventsAsInfo` for integrating teleoperator actions and events into transitions. - Added `Torch2NumpyActionProcessor` and `Numpy2TorchActionProcessor` for seamless conversion between PyTorch tensors and NumPy arrays. - Updated `__init__.py` to include new processors in module exports, improving modularity and clarity in the processing pipeline. - GymHIL is now fully supported with HIL using the pipeline * Refactor configuration structure for gym_hil integration - Renamed sections for better readability, such as changing "Gym Wrappers Configuration" to "Processor Configuration." - Enhanced documentation with clear examples for dataset collection and policy evaluation configurations. * Enhance reset configuration and teleoperation event handling - Added `terminate_on_success` parameter to `ResetConfig` and `InterventionActionProcessor` for controlling episode termination behavior upon success detection. - Updated documentation to clarify the impact of `terminate_on_success` on data collection for reward classifier training. - Refactored teleoperation event handling to use `TeleopEvents` constants for improved readability and maintainability across various modules. * fix(keyboard teleop), delta action keys * Added transform features and feature contract * Added transform features for image crop * Enum for TeleopEvents * Update tranform_features delta action proc --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Remove HILEnvConfig references * chore(processor): Add default names for preprocessor and postprocessor in constants - Introduced `PREPROCESSOR_DEFAULT_NAME` and `POSTPROCESSOR_DEFAULT_NAME` constants for consistent naming across various processor implementations. - Updated processor creation in multiple policy files to utilize these constants, enhancing code readability and maintainability. - Modified the training script to load and save the preprocessor and postprocessor using the new constants. * feat(processor): multiple improvements to the pipeline porting (#1749) * [Port codebase pipeline] General fixes for RL and scripts (#1748) * Refactor dataset configuration in documentation and codebase - Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency. - Adjusted replay episode handling by renaming `episode` to `replay_episode`. - Enhanced documentation - added specific processor to transform from policy actions to delta actions * Added Robot action to tensor processor Added new processor script for dealing with gym specific action processing * removed RobotAction2Tensor processor; imrpoved choosing observations in actor * nit in delta action * added missing reset functions to kinematics * Adapt teleoperate and replay to pipeline similar to record * refactor(processors): move to inheritance (#1750) * fix(teleoperator): improvements phone implementation (#1752) * fix(teleoperator): protect shared state in phone implementation * refactor(teleop): separate classes in phone * fix: solve breaking changes (#1753) * refactor(policies): multiple improvements (#1754) * refactor(processor): simpler logic in device processor (#1755) * refactor(processor): euclidean distance in delta action processor (#1757) * refactor(processor): improvements to joint observations processor migration (#1758) * refactor(processor): improvements to tokenizer migration (#1759) * refactor(processor): improvements to tokenizer migration * fix(tests): tokenizer tests regression from #1750 * fix(processors): fix float comparison and config in hil processors (#1760) * chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761) * refactor(processor): improvements normalize pipeline migration (#1756) * refactor(processor): several improvements normalize processor step * refactor(processor): more improvements normalize processor * refactor(processor): more changes to normalizer * refactor(processor): take a different approach to DRY * refactor(processor): final design * chore(record): revert comment and continue deleted (#1764) * refactor(examples): pipeline phone examples (#1769) * refactor(examples): phone teleop + teleop script * refactor(examples): phone replay + replay * chore(examples): rename phone example files & folders * feat(processor): fix improvements to the pipeline porting (#1796) * refactor(processor): enhance tensor device handling in normalization process (#1795) * refactor(tests): remove unsupported device detection test for complementary data (#1797) * chore(tests): update ToBatchProcessor test (#1798) * refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor * test(tests): add tests for action and task processing in batch processor * add names for android and ios phone (#1799) * use _tensor_stats in normalize processor (#1800) * fix(normalize_processor): correct device reference for tensor epsilon handling (#1801) * add point 5 add missing feature contracts (#1806) * Fix PR comments 1452 (#1807) * use key to determine image * Address rest of PR comments * use PolicyFeatures in transform_features --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --------- Co-authored-by: Michel Aractingi Co-authored-by: Adil Zouitine Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * refactor(constants, processor): standardize action and observation keys across multiple files (#1808) - Added new constants for truncated and done states in constants.py. - Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability. * refactor(processor): improve processor pipeline typing with generic type (#1810) * refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided. * refactor(converters): implement unified tensor conversion function (#1830) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior. * Revert "refactor(converters): implement unified tensor conversion function (#…" (#1840) This reverts commit a837685bf870919fc07ada287a71711cebabb1ea. * refactor(converters): implement unified tensor conversion function (#1841) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior. Co-authored-by: AdilZouitine * refactor(converters): gather converters and refactor the logic (#1833) * refactor(converters): move batch transition functions to converters module - Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns. - Updated references in `RobotProcessor` to use the new location of these functions. - Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields. - Removed redundant tests from `pipeline.py` to streamline the test suite. * refactor(processor): reorganize EnvTransition and TransitionKey definitions - Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability. - Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase. * refactor(converters): rename and update dataset frame conversion functions - Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming. - Updated references in `record.py`, `pipeline.py`, and tests to use the new function name. - Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability. - Adjusted related tests to ensure correct functionality with the new naming conventions. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(processor): solve conflict artefacts * refactor(converters): remove unused identity function and update type hints for merge_transitions * refactor(processor): remove unused identity import and clean up gym_manipulator.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * refactor(processors): add transform_features method to various processors (#1843) * refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844) * refactor(processors): unify import statements by consolidating pipeline imports into the main processor module (#1845) * refactor(processors): add extended api for specialized pipelines (#1848) * refactor(processors): enhance transform_features method across multiple processors (#1849) * refactor(processors): enhance transform_features method across multiple processors - Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features. - Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others. - Improved readability and maintainability by following consistent patterns in feature transformation. * refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor - Updated action and observation keys to use constants for improved readability and maintainability. - Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys. - Enhanced error handling by raising exceptions for missing required components in action and observation processing. - Removed obsolete code and improved overall structure for better clarity. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(processors): remove unused import in joint_observations_processor * refactor(processors): simplify transform_features method in delta_action_processor * refactor(processors): streamline transform_features method in ImageCropResizeProcessor * refactor(processors): improve error handling and streamline transform_features method in phone_processor - Raised a ValueError for missing position and rotation in action to enhance error handling. * refactor(processors): enhance error handling in JointVelocityProcessor - Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method. * refactor(processors): simplify transform_features method in robot kinematic processors * refactor(processors): standardize action keys in phone_processor * fix(processor): RKP feature obs -> act --------- Signed-off-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * chore(processor): rename RobotProcessor -> DataProcessorPipeline (#1850) * chore(processor): rename specialized processor -> XYZProcessorStep (#1852) * chore(processor): rename converters function names (#1853) * chore(processor): rename to_transition_teleop_action -> action_to_transition * chore(processor): rename to_transition_robot_observation -> observation_to_transition * chore(processor): rename to_output_robot_action -> transition_to_robot_action * chore(processor): add Step suffix to all processors (#1854) * refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency * refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules * refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency * refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency * refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency * refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency * refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency * refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency * refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency * refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency * refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency * refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency * refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency * refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency * refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency * refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency * refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency * refactor(processor): update config file name in test for RenameProcessorStep consistency * refactor(processor): rename internal tokenizer variable for clarity (#1855) - Changed the internal tokenizer variable name from `_tokenizer` to `input_tokenizer` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. * chore(processor): rename merge_features -> combine_feature_dicts (#1856) * refactor(processor): rename internal device variable for clarity (#1857) - Changed the internal device variable from `_device` to `tensor_device` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. * chore(processor): rename teleop_phone variable names (#1858) * chore(processor): add type alias RobotProcessorPipeline and PolicyProcessorPipeline (#1859) * feat(processor): introduce PolicyProcessorPipeline and RobotProcessorPipeline as type aliases for DataProcessorPipeline - Added PolicyProcessorPipeline and RobotProcessorPipeline type aliases to enhance clarity and maintainability in the processor module. - Updated the __all__ list to include the new pipelines for better module export consistency. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline across multiple modules - Updated all instances of DataProcessorPipeline to PolicyProcessorPipeline in various processor files for consistency and clarity. - Adjusted function signatures to reflect the new pipeline type, enhancing maintainability and readability. * refactor(processor): update hotswap_stats function to use PolicyProcessorPipeline - Changed the parameter name from robot_processor to policy_processor for clarity. - Ensured consistency with recent updates to the processor module by reflecting the new pipeline type in the function signature. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline in migrate_policy_normalization.py - Updated the preprocessor and postprocessor to use PolicyProcessorPipeline for consistency with recent changes in the processor module. - Enhanced clarity and maintainability by aligning with the new pipeline structure. * refactor(processor): update hotswap_stats to use PolicyProcessorPipeline - Changed the parameter type in hotswap_stats from DataProcessorPipeline to PolicyProcessorPipeline for consistency with recent updates. - Enhanced clarity by updating the function documentation to reflect the new pipeline type. * refactor(processor): replace DataProcessorPipeline with RobotProcessorPipeline across multiple files - Updated instances of DataProcessorPipeline to RobotProcessorPipeline in evaluate.py, record.py, replay.py, teleoperate.py, and other relevant files for consistency and clarity. - Adjusted function signatures and variable types to reflect the new pipeline structure, enhancing maintainability and readability. * refactor(processor): enforce config_filename requirement for HF Hub loading (#1860) - Updated the DataProcessorPipeline to require a specific config_filename when loading from Hugging Face Hub, enhancing clarity and preventing errors. - Simplified local path checks and improved error handling for invalid paths. - Adjusted tests to reflect the new requirement and ensure proper error handling for various loading scenarios. * feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861) - Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording. - Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features. * refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862) - Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. - Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability. * refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863) - Removed the scipy dependency from the project to streamline requirements. - Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies. - Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities. * feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866) - Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events. - Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method. - Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators. - Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators. * fix(deps): use in-house rotation utils over scipy throughout the codebase * refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868) - Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context. - Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase. * refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869) - Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring. * refactor(utils): simplify log_rerun_data function (#1864) * refactor(logging): enhance log_rerun_data to handle observation and action separately - Updated the `log_rerun_data` function to accept and log observation and action data more clearly, improving readability and maintainability. - Refactored the `record_loop` and `teleop_loop` functions to extract and pass observation and action data to `log_rerun_data`, ensuring consistent logging format. * refactor(tests): update test_log_rerun_data to align with log_rerun_data changes - Modified test cases in `test_visualization_utils.py` to extract and pass observation and action data separately to `log_rerun_data`, improving clarity and consistency with recent function updates. - Ensured that the tests reflect the new structure of `log_rerun_data` for better maintainability. * refactor(processors): simplify calls to log_rerun + replace lambda functions with identity_transition --------- Co-authored-by: Steven Palma * fix(processor): recover type inference for use of processors (#1873) * refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880) * refactor(processors): reorder processor steps for consistency across implementations - Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep. - Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability. * refactor(normalization): remove dtype specification in tensor conversion for adaptation logic - Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types. - Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types. * chore(docs): update doctrines pipeline files (#1872) * docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions (#1900) * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions - Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline. - Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow. * refactor(eval): remove redundant observation device conversion in rollout function - Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability. - This change simplifies the observation handling process, aligning with the preference for clearer solutions. * debug * refactor(utils): enhance task handling in add_envs_task function - Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings. - Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability. - Streamlined the observation dictionary handling by ensuring consistent data types for task attributes. * refactor(converters): rename _from_tensor to from_tensor_to_numpy for clarity (#1902) - Updated the function name from _from_tensor to from_tensor_to_numpy to better reflect its purpose of converting PyTorch tensors to numpy arrays or scalars. - Adjusted all references to the renamed function throughout the codebase to maintain consistency. - Enhanced the _NormalizationMixin class to reconstruct the stats dictionary from tensor stats using the new function, ensuring compatibility after loading state dicts. - Added tests to verify the correct reconstruction of stats and functionality of methods dependent on self.stats after loading. * refactor(pipeline): feature contract now categorizes between OBS or Action (#1867) * refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor * refactor(eval): specify type parameters for preprocessor and postprocessor in eval_policy function (#1904) * chore(processor): remove action prefixes (#1905) * test(processor): all processors use now the same create_transition (#1906) * test(processor): all processors use now the same create_transition * test(processor): use identity instead of lambda for transition in pipelines * fix(processor): specialized processors respect contract by raising if none (#1909) * fix(processor): specialized processor now raise * test(processor): fix tests for now raise specialized processors * test(processor): use identity in newly introduced pipeline * refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908) * refactor(processor): split action from policy, robots and environment - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions. - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention. - Enhanced type annotations for action parameters to improve code readability and maintainability. * refactor(converters): rename robot_transition_to_action to transition_to_robot_action - Updated function names across multiple files to improve clarity and consistency in processing robot actions. - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention. - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence. * refactor(converters): update references to transition_to_robot_action - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions. - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability. * refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing. - This modification enhances the clarity of the class's role in the processing pipeline. * fix(processor): main action processor can take also EnvAction --------- Co-authored-by: Steven Palma * refactor(processor): phone processor is now an RobotActionProcessorStep * fix(processor): use subprocessors in AddBatchDimensionProcessorStep only if we have the ingredients * fix(robots): remove action prefix hard-coded in teleop keyboard and gamepad * feat(processor): enhance type safety with generic DataProcessorPipeline for policy and robot pipelines (#1915) * refactor(processor): enhance type annotations for processors in record, replay, teleoperate, and control utils - Updated type annotations for preprocessor and postprocessor parameters in record_loop and predict_action functions to specify the expected dictionary types. - Adjusted robot_action_processor type in ReplayConfig and TeleoperateConfig to improve clarity and maintainability. - Ensured consistency in type definitions across multiple files, enhancing overall code readability. * refactor(processor): enhance type annotations for RobotProcessorPipeline in various files - Updated type annotations for RobotProcessorPipeline instances in evaluate.py, record.py, replay.py, teleoperate.py, and other related files to specify input and output types more clearly. - Introduced new type conversions for PolicyAction and EnvTransition to improve type safety and maintainability across the processing pipelines. - Ensured consistency in type definitions, enhancing overall code readability and reducing potential runtime errors. * refactor(processor): update transition handling in processors to use transition_to_batch - Replaced direct transition handling with transition_to_batch in various processor tests and implementations to ensure consistent batching of input data. - Updated assertions in tests to reflect changes in data structure, enhancing clarity and maintainability. - Improved overall code readability by standardizing the way transitions are processed across different processor types. * refactor(tests): standardize transition key usage in processor tests - Updated assertions in processor test files to utilize the TransitionKey for action references, enhancing consistency across tests. - Replaced direct string references with TransitionKey constants for improved readability and maintainability. - Ensured that all relevant tests reflect these changes, contributing to a more uniform approach in handling transitions. * refactor(processor): unify action imports and enhance type clarity across multiple files - Updated imports in various files to include RobotAction and PolicyAction directly from the processor module, improving clarity and consistency. - Removed redundant imports from core, streamlining the codebase and enhancing maintainability. - Adjusted type annotations and references in the RobotProcessorPipeline and related components to align with the new import structure, ensuring better type safety and readability. * refactor(processor): migrate policy normalization to use factory functions - Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase. - Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture. - Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config. - Enhanced type safety and readability by refining feature type and normalization mode handling. * debug(scripts): simplify record with processors (#1918) Co-authored-by: Adil Zouitine * refactor(processor): update migration script for policy normalization and hub integration - Modified the migration script to include a branch argument for pushing to the hub, enhancing flexibility in version control. - Improved error handling by ensuring the policy type is extracted from the configuration, promoting robustness. - Streamlined the process of saving and pushing model components to the hub, allowing for a single commit with optional PR creation. - Updated the commit message and description for better clarity on the migration changes and benefits, ensuring users are informed of the new architecture and usage. * fixes for processors used in phone teleop * fixes for rotation matrix * add empty obs and act in create_initial_features * use observation instead of obs * docs(processor): update docstrings pipeline (#1920) * chore(docs): Processor doc (#1685) * chore(docs): initialize doc * Added script for the second part of the processor doc * precommit style nit * improved part 2 of processor guide * Add comprehensive documentation for processors in robotics - Introduced a detailed guide on processors, covering their role in transforming raw robot data into model-ready inputs and vice versa. - Explained core concepts such as EnvTransition, ProcessorStep, and RobotProcessor, along with their functionalities. - Included examples of common processor steps like normalization, device management, batch processing, and text tokenization. - Provided insights on building complete pipelines, integrating processors into training loops, and saving/loading configurations. - Emphasized best practices and advanced features for effective usage of processors in robotics applications. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(docs): Enhance introduction to processors with additional converter functions - Updated the introduction to processors documentation to include default batch-to-transition and transition-to-batch converters. - Added detailed descriptions and examples for new specialized converter functions: `to_transition_teleop_action`, `to_transition_robot_observation`, `to_output_robot_action`, and `to_dataset_frame`. - Improved clarity on how these converters facilitate integration with existing robotics applications. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Improved doc implement_your_own_pipeline - Use normalization processor as default example - Add section on transform features - Add section on overrides. * Add phone docs and use pipeline for robots/teleop docs * Fix typo in documentation for adapters in robots/teleop section * Enhance documentation for processors with detailed explanations and examples - Updated the introduction to processors, clarifying the role of `EnvTransition` and `ProcessorStep`. - Introduced `DataProcessorPipeline` as a generic orchestrator for chaining processor steps. - Added comprehensive descriptions of new converter functions and their applications. - Improved clarity on type safety and the differences between `RobotProcessorPipeline` and `PolicyProcessorPipeline`. - Included examples for various processing scenarios, emphasizing best practices for data handling in robotics. * Enhance documentation for processor migration and debugging - Added detailed sections on the migration of models to the new `PolicyProcessorPipeline` system, including breaking changes and migration scripts. - Introduced a comprehensive guide for debugging processor pipelines, covering common issues, step-by-step inspection, and runtime monitoring techniques. - Updated examples to reflect new usage patterns and best practices for processor implementation and error handling. - Clarified the role of various processor steps and their configurations in the context of robotics applications. --------- Co-authored-by: Michel Aractingi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pepijn * docs: Add new section for debugging processor pipelines - Introduced a new documentation entry for debugging processor pipelines, enhancing the existing guide on processors. - This addition aims to provide users with insights and best practices for troubleshooting and optimizing their processor workflows. * fix(processor): phone examples (#1921) * fix(processor): phone examples * chore(processor): simplify gripper in phone example kinematic chain --------- Co-authored-by: Steven Palma * refactor(processors): several additions (#1926) * chore(processor): remove merge_transitions functions (#1925) * refactor(processors): move processors out of configs (#1927) * chore(processor): streamline combine_features_dict (#1928) * chore(policies): use new constants (#1929) * fix(deps): right version transformers (#1930) * fix(tests): add none + disable async tests for now (#1931) * refactor(processor): transform_features loop + EAFP (#1932) * fix(processors): make sure nested dict are also shallow copied (#1939) * refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937) - Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving * refactor(docs): streamline monitoring hooks and enhance performance reporting - Removed the log_shapes and measure_performance hooks, simplifying the monitoring process to focus on NaN checks. - Updated performance reporting to include maximum processing times alongside average times for better insights. - Clarified documentation regarding the processing pipeline and feature transformations. * fix teleop, record and eval (#1940) * fix cmd record, eval * chore(processor): update input output of main 3 processors for better semantics (#1942) * chore(processor): update input output of main 3 processors for better semantics * refactor(processor): replace Any with RobotObservation for improved type safety in processors * fix(processors): no PolicyObservation * chore(processor): update with RobotObservation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: AdilZouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * test(processor): fix batch expectation * feat(example): Add SO100 EE pipeline control (teleop+record) (#1943) * feat(examples): add ee so100 processors teleop & record * refactor(processor): improve FK processor for better use compatability * docs(processor): enhance tutorial on implementing custom processors - Updated the tutorial to use `NormalizerProcessorStep` as the primary example, clarifying its role in normalizing observations and actions. - Improved explanations of the need for custom processors, emphasizing data compatibility and processing requirements. - Added code snippets demonstrating the normalization process and the configuration of processor pipelines. - Enhanced the introduction to processors, detailing their function as translators between raw robot data and model inputs. - Included examples of real-world processor configurations for both training and inference scenarios. * docs(debug): enhance debugging guide for processor pipelines - Streamlined the introduction to clarify the challenges of debugging complex processor pipelines. - Expanded the section on hooks, detailing their purpose and implementation for runtime monitoring. - Introduced step-by-step debugging techniques, emphasizing the use of the `step_through()` method for inspecting intermediate states. - Added examples of feature validation to ensure data structure contracts are met. - Consolidated best practices for debugging, highlighting the synergy between hooks, step-through debugging, and feature validation. * chore(processors): tokenizers raises and remove tensor conversion (#1949) * chore(processor): remove unused transition_features dict * feat(ee): add so100_to_so100_EE replay and evaluate examples * chore(examples): homogenize style across example files (#1955) * chore(examples): homogenize style across example files * chore(examples): homogenize style across example files eval + replay * chore(examples): homogenize headers * test(async): fix feature manipulation (#1957) * test(async): fix feature manipulation * chore(processor): remove unused functions * fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958) * feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity. * test(processor): update tests to handle missing or invalid task keys - Modified tests to assert that the processor raises appropriate exceptions when the task key is missing or has an invalid value in the complementary data. - Ensured that the tests cover cases for None, integer, and mixed list task values, improving robustness against invalid inputs. * fix(processor): enforce signatures * chore(processor): update comments in record.py * test(processor): fix isinstance and cuda test * modify phone docs * fix(processor): reorder output steps to ensure correct processing sequence (#1961) - Moved DeviceProcessorStep to the end of the output steps in multiple processor files to maintain the intended processing order. - Updated corresponding tests to reflect the change in step order. * fix(processors): assumptions for robot_action_processor & teleop_action_processor (#1964) * fix(processors): new assumptions pipeline * fix(processors): ee jj phone teleop replay record working * chore(processors): update comments and default vars * chore(processor): remove unnecessary copy * chore(processor): added todo assumption gripper * fix(processors): eval using detected device * finish phone docs * fix correct image link * feat(processor): implement migration detection and error handling for processor configurations (#1968) * feat(processor): implement migration detection and error handling for processor configurations - Added ProcessorMigrationError to handle migration requirements for old model formats. - Enhanced DataProcessorPipeline.from_pretrained to include robust migration detection logic. - Implemented methods for resolving configuration sources, validating loaded configs, and checking for valid processor configurations. - Introduced comprehensive tests for migration detection and configuration validation to ensure correct behavior. * refactor(processor): simplify loading logic and enhance migration detection - Refactored DataProcessorPipeline to implement a simplified three-way loading strategy for configuration files. - Introduced explicit config_filename parameter to avoid ambiguity during loading. - Updated ProcessorMigrationError to provide clearer error messages for migration requirements. - Enhanced tests to cover new loading logic and ensure proper migration detection. - Removed deprecated methods related to config source resolution. * fix(processor) RL (#1953) * fix(gym_manipulator) general fixes to make it compitable * fix for dataset v3.0 * fix for gym_manipulator * add map policy action to robot action wrappers in a seperate scripts * added unittest for policy to robot bridge * fixes for gripper penalty * fix style * fix gamepad controller * fixes for sim teleop * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify numpy2torch to a regular processor as a quick fix * missing imports?! * - Removed the use of `AddRobotObservationAsComplimentaryData` from `gym_manipulator` and thus the codebase - Added get_raw_joint_positions functions to RobotEnv - Pass raw_joint_positions as input to the action_pipeline in `gym_manipulator` - Add `InverseKinematicsRLStep` to be tailored towards the need of RL which requires the use of the IK solution as the main reference point of the control loop - Added the option `use_ik_solution` in `EEReferenceDelta` step to rely on the ik solution rather than the joint values * -Updated links to all the config files to place them in the new repo with configs compatible with the pipeline --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * fix(tests): update test cases for loading pipelines with specific config filenames - Modified test cases to include explicit configuration filenames when loading pipelines in `test_policy_robot_bridge.py`. - Ensured that the tests reflect the correct loading behavior for both robot-to-policy and policy-to-robot transitions. * fix(examples): train mps processor (#1970) * fix(examples): train mps processor * fix(processor): add MPS compatibility for float64 tensors - Implemented a workaround to convert float64 tensors to float32 when using the MPS device, as MPS does not support float64. - Added unit tests to verify the automatic conversion of float64 tensors to float32 and ensure compatibility with various tensor types on the MPS device. --------- Co-authored-by: AdilZouitine --------- Signed-off-by: Adil Zouitine Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: Michel Aractingi Co-authored-by: Steven Palma Co-authored-by: Pepijn --- .gitignore | 4 + docs/source/_toctree.yml | 19 +- docs/source/backwardcomp.mdx | 56 + docs/source/debug_processor_pipeline.mdx | 299 ++ docs/source/hilserl.mdx | 443 ++- docs/source/hilserl_sim.mdx | 90 +- docs/source/il_robots.mdx | 15 +- docs/source/il_sim.mdx | 60 +- docs/source/implement_your_own_processor.mdx | 273 ++ docs/source/introduction_processors.mdx | 314 ++ docs/source/phone_teleop.mdx | 192 ++ docs/source/processors_robots_teleop.mdx | 151 + examples/3_train_policy.py | 8 +- examples/5_train_with_streaming.py | 14 +- examples/lekiwi/evaluate.py | 63 +- examples/lekiwi/record.py | 51 +- examples/lekiwi/replay.py | 33 +- examples/lekiwi/teleoperate.py | 35 +- examples/phone_to_so100/evaluate.py | 197 ++ examples/phone_to_so100/record.py | 204 ++ examples/phone_to_so100/replay.py | 99 + examples/phone_to_so100/teleoperate.py | 114 + examples/so100_to_so100_EE/evaluate.py | 198 ++ examples/so100_to_so100_EE/record.py | 203 ++ examples/so100_to_so100_EE/replay.py | 100 + examples/so100_to_so100_EE/teleoperate.py | 122 + pyproject.toml | 6 +- src/lerobot/configs/policies.py | 3 +- src/lerobot/configs/types.py | 6 + src/lerobot/constants.py | 9 + src/lerobot/datasets/pipeline_features.py | 141 + src/lerobot/datasets/utils.py | 455 ++- src/lerobot/envs/configs.py | 143 +- src/lerobot/envs/factory.py | 4 +- src/lerobot/envs/utils.py | 24 +- src/lerobot/policies/__init__.py | 11 + src/lerobot/policies/act/modeling_act.py | 16 - src/lerobot/policies/act/processor_act.py | 85 + .../policies/diffusion/modeling_diffusion.py | 16 - .../policies/diffusion/processor_diffusion.py | 92 + src/lerobot/policies/factory.py | 243 +- src/lerobot/policies/normalize.py | 420 --- src/lerobot/policies/pi0/modeling_pi0.py | 146 +- src/lerobot/policies/pi0/processor_pi0.py | 166 ++ .../policies/pi0fast/modeling_pi0fast.py | 15 - .../policies/pi0fast/processor_pi0fast.py | 92 + src/lerobot/policies/sac/modeling_sac.py | 63 +- src/lerobot/policies/sac/processor_sac.py | 92 + .../sac/reward_model/modeling_classifier.py | 15 - .../sac/reward_model/processor_classifier.py | 82 + .../policies/smolvla/modeling_smolvla.py | 171 +- .../policies/smolvla/processor_smolvla.py | 141 + src/lerobot/policies/tdmpc/modeling_tdmpc.py | 24 +- src/lerobot/policies/tdmpc/processor_tdmpc.py | 90 + src/lerobot/policies/vqbet/modeling_vqbet.py | 17 +- src/lerobot/policies/vqbet/processor_vqbet.py | 91 + src/lerobot/processor/__init__.py | 133 +- src/lerobot/processor/batch_processor.py | 254 ++ src/lerobot/processor/converters.py | 412 +++ src/lerobot/processor/core.py | 56 + .../processor/delta_action_processor.py | 145 + src/lerobot/processor/device_processor.py | 188 +- src/lerobot/processor/factory.py | 62 + src/lerobot/processor/gym_action_processor.py | 97 + src/lerobot/processor/hil_processor.py | 596 ++++ .../processor/joint_observations_processor.py | 211 ++ .../processor/migrate_policy_normalization.py | 646 +++++ src/lerobot/processor/normalize_processor.py | 710 +++-- .../processor/observation_processor.py | 149 +- src/lerobot/processor/pipeline.py | 2174 ++++++++------ src/lerobot/processor/policy_robot_bridge.py | 52 + src/lerobot/processor/rename_processor.py | 62 +- src/lerobot/processor/tokenizer_processor.py | 270 ++ src/lerobot/record.py | 171 +- src/lerobot/replay.py | 19 +- src/lerobot/robots/so100_follower/__init__.py | 3 +- .../so100_follower/config_so100_follower.py | 32 - .../robot_kinematic_processor.py | 616 ++++ .../so100_follower_end_effector.py | 200 -- src/lerobot/robots/utils.py | 5 +- src/lerobot/scripts/eval.py | 46 +- src/lerobot/scripts/rl/actor.py | 121 +- src/lerobot/scripts/rl/gym_manipulator.py | 2553 ++++------------- src/lerobot/scripts/rl/learner.py | 26 +- src/lerobot/scripts/train.py | 70 +- src/lerobot/teleoperate.py | 95 +- src/lerobot/teleoperators/__init__.py | 2 +- .../teleoperators/gamepad/gamepad_utils.py | 24 +- .../teleoperators/gamepad/teleop_gamepad.py | 43 + .../teleoperators/keyboard/teleop_keyboard.py | 74 +- src/lerobot/teleoperators/phone/__init__.py | 18 + .../teleoperators/phone/config_phone.py | 36 + .../teleoperators/phone/phone_processor.py | 110 + .../teleoperators/phone/teleop_phone.py | 421 +++ src/lerobot/teleoperators/utils.py | 12 + src/lerobot/utils/control_utils.py | 96 +- src/lerobot/utils/import_utils.py | 1 + src/lerobot/utils/rotation.py | 270 ++ src/lerobot/utils/train_utils.py | 12 +- src/lerobot/utils/visualization_utils.py | 83 +- .../actions.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../actions.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../pusht_diffusion_/actions.safetensors | 2 +- .../pusht_diffusion_/grad_stats.safetensors | 2 +- .../pusht_diffusion_/param_stats.safetensors | 4 +- .../policies/save_policy_to_safetensors.py | 13 +- .../actions.safetensors | 2 +- .../grad_stats.safetensors | 2 +- .../output_dict.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../actions.safetensors | 2 +- .../grad_stats.safetensors | 2 +- .../output_dict.safetensors | 2 +- .../param_stats.safetensors | 4 +- tests/conftest.py | 10 +- tests/datasets/test_dataset_utils.py | 132 + tests/datasets/test_utils.py | 86 - tests/policies/test_policies.py | 110 +- tests/processor/test_act_processor.py | 412 +++ tests/processor/test_batch_conversion.py | 56 +- tests/processor/test_batch_processor.py | 1184 ++++++++ tests/processor/test_classifier_processor.py | 362 +++ tests/processor/test_converters.py | 292 ++ tests/processor/test_device_processor.py | 1161 ++++++++ tests/processor/test_diffusion_processor.py | 398 +++ tests/processor/test_migration_detection.py | 341 +++ tests/processor/test_normalize_processor.py | 1415 ++++++++- tests/processor/test_observation_processor.py | 209 +- tests/processor/test_pi0_processor.py | 424 +++ tests/processor/test_pipeline.py | 753 +++-- .../test_pipeline_from_pretrained_helpers.py | 259 ++ tests/processor/test_policy_robot_bridge.py | 525 ++++ tests/processor/test_rename_processor.py | 200 +- tests/processor/test_sac_processor.py | 414 +++ tests/processor/test_smolvla_processor.py | 459 +++ tests/processor/test_tdmpc_processor.py | 467 +++ tests/processor/test_tokenizer_processor.py | 1029 +++++++ tests/processor/test_vqbet_processor.py | 462 +++ tests/utils/test_visualization_utils.py | 209 ++ 141 files changed, 23478 insertions(+), 5556 deletions(-) create mode 100644 docs/source/debug_processor_pipeline.mdx create mode 100644 docs/source/implement_your_own_processor.mdx create mode 100644 docs/source/introduction_processors.mdx create mode 100644 docs/source/phone_teleop.mdx create mode 100644 docs/source/processors_robots_teleop.mdx create mode 100644 examples/phone_to_so100/evaluate.py create mode 100644 examples/phone_to_so100/record.py create mode 100644 examples/phone_to_so100/replay.py create mode 100644 examples/phone_to_so100/teleoperate.py create mode 100644 examples/so100_to_so100_EE/evaluate.py create mode 100644 examples/so100_to_so100_EE/record.py create mode 100644 examples/so100_to_so100_EE/replay.py create mode 100644 examples/so100_to_so100_EE/teleoperate.py create mode 100644 src/lerobot/datasets/pipeline_features.py create mode 100644 src/lerobot/policies/act/processor_act.py create mode 100644 src/lerobot/policies/diffusion/processor_diffusion.py delete mode 100644 src/lerobot/policies/normalize.py create mode 100644 src/lerobot/policies/pi0/processor_pi0.py create mode 100644 src/lerobot/policies/pi0fast/processor_pi0fast.py create mode 100644 src/lerobot/policies/sac/processor_sac.py create mode 100644 src/lerobot/policies/sac/reward_model/processor_classifier.py create mode 100644 src/lerobot/policies/smolvla/processor_smolvla.py create mode 100644 src/lerobot/policies/tdmpc/processor_tdmpc.py create mode 100644 src/lerobot/policies/vqbet/processor_vqbet.py create mode 100644 src/lerobot/processor/batch_processor.py create mode 100644 src/lerobot/processor/converters.py create mode 100644 src/lerobot/processor/core.py create mode 100644 src/lerobot/processor/delta_action_processor.py create mode 100644 src/lerobot/processor/factory.py create mode 100644 src/lerobot/processor/gym_action_processor.py create mode 100644 src/lerobot/processor/hil_processor.py create mode 100644 src/lerobot/processor/joint_observations_processor.py create mode 100644 src/lerobot/processor/migrate_policy_normalization.py create mode 100644 src/lerobot/processor/policy_robot_bridge.py create mode 100644 src/lerobot/processor/tokenizer_processor.py create mode 100644 src/lerobot/robots/so100_follower/robot_kinematic_processor.py delete mode 100644 src/lerobot/robots/so100_follower/so100_follower_end_effector.py create mode 100644 src/lerobot/teleoperators/phone/__init__.py create mode 100644 src/lerobot/teleoperators/phone/config_phone.py create mode 100644 src/lerobot/teleoperators/phone/phone_processor.py create mode 100644 src/lerobot/teleoperators/phone/teleop_phone.py create mode 100644 src/lerobot/utils/rotation.py create mode 100644 tests/datasets/test_dataset_utils.py delete mode 100644 tests/datasets/test_utils.py create mode 100644 tests/processor/test_act_processor.py create mode 100644 tests/processor/test_batch_processor.py create mode 100644 tests/processor/test_classifier_processor.py create mode 100644 tests/processor/test_converters.py create mode 100644 tests/processor/test_device_processor.py create mode 100644 tests/processor/test_diffusion_processor.py create mode 100644 tests/processor/test_migration_detection.py create mode 100644 tests/processor/test_pi0_processor.py create mode 100644 tests/processor/test_pipeline_from_pretrained_helpers.py create mode 100644 tests/processor/test_policy_robot_bridge.py create mode 100644 tests/processor/test_sac_processor.py create mode 100644 tests/processor/test_smolvla_processor.py create mode 100644 tests/processor/test_tdmpc_processor.py create mode 100644 tests/processor/test_tokenizer_processor.py create mode 100644 tests/processor/test_vqbet_processor.py create mode 100644 tests/utils/test_visualization_utils.py diff --git a/.gitignore b/.gitignore index c4d1f769..b47e22cb 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,7 @@ outputs/ # Dev folders .cache/* +*.stl +*.urdf +*.xml +*.part diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 9f5de823..7d6b69fb 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -30,9 +30,18 @@ - local: smolvla title: Finetune SmolVLA title: "Policies" + +- sections: + - local: introduction_processors + title: Introduction to Robot Processors + - local: debug_processor_pipeline + title: Debug your processor pipeline + - local: implement_your_own_processor + title: Implement your own processor + - local: processors_robots_teleop + title: Processors for Robots and Teleoperators + title: "Robot Processors" - sections: - - local: hope_jr - title: Hope Jr - local: so101 title: SO-101 - local: so100 @@ -41,9 +50,15 @@ title: Koch v1.1 - local: lekiwi title: LeKiwi + - local: hope_jr + title: Hope Jr - local: reachy2 title: Reachy 2 title: "Robots" +- sections: + - local: phone_teleop + title: Phone + title: "Teleoperators" - sections: - local: notebooks title: Notebooks diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx index 0e1d0163..3366c8ab 100644 --- a/docs/source/backwardcomp.mdx +++ b/docs/source/backwardcomp.mdx @@ -1,5 +1,61 @@ # Backward compatibility +## Policy Normalization Migration (PR #1452) + +**Breaking Change**: LeRobot policies no longer have built-in normalization layers embedded in their weights. Normalization is now handled by external `PolicyProcessorPipeline` components. + +### What changed? + +| | Before PR #1452 | After PR #1452 | +| -------------------------- | ------------------------------------------------ | ------------------------------------------------------------ | +| **Normalization Location** | Embedded in model weights (`normalize_inputs.*`) | External `PolicyProcessorPipeline` components | +| **Model State Dict** | Contains normalization statistics | **Clean weights only** - no normalization parameters | +| **Usage** | `policy(batch)` handles everything | `preprocessor(batch)` → `policy(...)` → `postprocessor(...)` | + +### Impact on existing models + +- Models trained **before** PR #1452 have normalization embedded in their weights +- These models need migration to work with the new `PolicyProcessorPipeline` system +- The migration extracts normalization statistics and creates separate processor pipelines + +### Migrating old models + +Use the migration script to convert models with embedded normalization: + +```shell +python src/lerobot/processor/migrate_policy_normalization.py \ + --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ + --push-to-hub \ + --branch migrated +``` + +The script: + +1. **Extracts** normalization statistics from model weights +2. **Creates** external preprocessor and postprocessor pipelines +3. **Removes** normalization layers from model weights +4. **Saves** clean model + processor pipelines +5. **Pushes** to Hub with automatic PR creation + +### Using migrated models + +```python +# New usage pattern (after migration) +from lerobot.policies.factory import make_policy, make_pre_post_processors + +# Load model and processors separately +policy = make_policy(config, ds_meta=dataset.meta) +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=config, + dataset_stats=dataset.meta.stats +) + +# Process data through pipeline +processed_batch = preprocessor(raw_batch) +action = policy.select_action(processed_batch) +final_action = postprocessor(action) +``` + ## Hardware API redesign PR [#777](https://github.com/huggingface/lerobot/pull/777) improves the LeRobot calibration but is **not backward-compatible**. Below is a overview of what changed and how you can continue to work with datasets created before this pull request. diff --git a/docs/source/debug_processor_pipeline.mdx b/docs/source/debug_processor_pipeline.mdx new file mode 100644 index 00000000..4826c947 --- /dev/null +++ b/docs/source/debug_processor_pipeline.mdx @@ -0,0 +1,299 @@ +# Debug Your Processor Pipeline + +Processor pipelines can be complex, especially when chaining multiple transformation steps. +Unlike simple function calls, pipelines lack natural observability, you can't easily see what happens +between each step or where things go wrong. +This guide provides debugging tools and techniques specifically designed to address these challenges +and help you understand data flow through your pipelines. + +We'll explore three complementary debugging approaches: **hooks** for runtime monitoring, **step-through debugging** for detailed inspection, and **feature validation** for catching structural mismatches. Each serves a different purpose and together they provide complete visibility into your pipeline's behavior. + +## Understanding Hooks + +Hooks are functions that get called at specific points during pipeline execution. +They provide a way to inspect, monitor, or modify data without changing your pipeline code. +Think of them as "event listeners" for your pipeline. + +### What is a Hook? + +A hook is a callback function that gets automatically invoked at specific moments during pipeline execution. +The concept comes from event-driven programming, imagine you could "hook into" the pipeline's execution flow to observe or react to what's happening. + +Think of hooks like inserting checkpoints into your pipeline. Every time the pipeline reaches one of these checkpoints, it pauses briefly to call your hook function, giving you a chance to inspect the current state, log information, and validate data. + +A hook is simply a function that accepts two parameters: + +- `step_idx: int` - The index of the current processing step (0, 1, 2, etc.) +- `transition: EnvTransition` - The data transition at that point in the pipeline + +The beauty of hooks is their non-invasive nature: you can add monitoring, validation, or debugging logic without changing a single line of your pipeline code. The pipeline remains clean and focused on its core logic, while hooks handle the cross-cutting concerns like logging, monitoring, and debugging. + +### Before vs After Hooks + +The pipeline supports two types of hooks: + +- **Before hooks** (`register_before_step_hook`) - Called before each step executes +- **After hooks** (`register_after_step_hook`) - Called after each step completes + +```python +def before_hook(step_idx: int, transition: EnvTransition): + """Called before step processes the transition.""" + print(f"About to execute step {step_idx}") + # Useful for: logging, validation, setup + +def after_hook(step_idx: int, transition: EnvTransition): + """Called after step has processed the transition.""" + print(f"Completed step {step_idx}") + # Useful for: monitoring results, cleanup, debugging + +processor.register_before_step_hook(before_hook) +processor.register_after_step_hook(after_hook) +``` + +### Implementing a NaN Detection Hook + +Here's a practical example of a hook that detects NaN values: + +```python +def check_nans(step_idx: int, transition: EnvTransition): + """Check for NaN values in observations.""" + obs = transition.get(TransitionKey.OBSERVATION) + if obs: + for key, value in obs.items(): + if isinstance(value, torch.Tensor) and torch.isnan(value).any(): + print(f"NaN detected in {key} at step {step_idx}") + +# Register the hook to run after each step +processor.register_after_step_hook(check_nans) + +# Process your data - the hook will be called automatically +output = processor(input_data) + +# Remove the hook when done debugging +processor.unregister_after_step_hook(check_nans) +``` + +### How Hooks Work Internally + +Understanding the internal mechanism helps you use hooks more effectively. The pipeline maintains two separate lists: one for before-step hooks and another for after-step hooks. When you register a hook, it's simply appended to the appropriate list. + +During execution, the pipeline follows a strict sequence: for each processing step, it first calls all before-hooks in registration order, then executes the actual step transformation, and finally calls all after-hooks in registration order. This creates a predictable, sandwich-like structure around each step. + +The key insight is that hooks don't change the core pipeline logic—they're purely additive. The pipeline's `_forward` method orchestrates this dance between hooks and processing steps, ensuring that your debugging or monitoring code runs at exactly the right moments without interfering with the main data flow. + +Here's a simplified view of how the pipeline executes hooks: + +```python +class DataProcessorPipeline: + def __init__(self): + self.steps = [...] + self.before_step_hooks = [] # List of before hooks + self.after_step_hooks = [] # List of after hooks + + def _forward(self, transition): + """Internal method that processes the transition through all steps.""" + for step_idx, processor_step in enumerate(self.steps): + # 1. Call all BEFORE hooks + for hook in self.before_step_hooks: + hook(step_idx, transition) + + # 2. Execute the actual processing step + transition = processor_step(transition) + + # 3. Call all AFTER hooks + for hook in self.after_step_hooks: + hook(step_idx, transition) + + return transition + + def register_before_step_hook(self, hook_fn): + self.before_step_hooks.append(hook_fn) + + def register_after_step_hook(self, hook_fn): + self.after_step_hooks.append(hook_fn) +``` + +### Execution Flow + +The execution flow looks like this: + +``` +Input → Before Hook → Step 0 → After Hook → Before Hook → Step 1 → After Hook → ... → Output +``` + +For example, with 3 steps and both hook types: + +```python +def timing_before(step_idx, transition): + print(f"⏱️ Starting step {step_idx}") + +def validation_after(step_idx, transition): + print(f"✅ Completed step {step_idx}") + +processor.register_before_step_hook(timing_before) +processor.register_after_step_hook(validation_after) + +# This will output: +# ⏱️ Starting step 0 +# ✅ Completed step 0 +# ⏱️ Starting step 1 +# ✅ Completed step 1 +# ⏱️ Starting step 2 +# ✅ Completed step 2 +``` + +### Multiple Hooks + +You can register multiple hooks of the same type - they execute in the order registered: + +```python +def log_shapes(step_idx: int, transition: EnvTransition): + obs = transition.get(TransitionKey.OBSERVATION) + if obs: + print(f"Step {step_idx} observation shapes:") + for key, value in obs.items(): + if isinstance(value, torch.Tensor): + print(f" {key}: {value.shape}") + +processor.register_after_step_hook(check_nans) # Executes first +processor.register_after_step_hook(log_shapes) # Executes second + +# Both hooks will be called after each step in registration order +output = processor(input_data) +``` + +While hooks are excellent for monitoring specific issues (like NaN detection) or gathering metrics during normal pipeline execution, sometimes you need to dive deeper. When you want to understand exactly what happens at each step or debug complex transformation logic, step-through debugging provides the detailed inspection you need. + +## Step-Through Debugging + +Step-through debugging is like having a slow-motion replay for your pipeline. Instead of watching your data get transformed in one quick blur from input to output, you can pause and examine what happens after each individual step. + +This approach is particularly valuable when you're trying to understand a complex pipeline, debug unexpected behavior, or verify that each transformation is working as expected. Unlike hooks, which are great for automated monitoring, step-through debugging gives you manual, interactive control over the inspection process. + +The `step_through()` method is a generator that yields the transition state after each processing step, allowing you to inspect intermediate results. Think of it as creating a series of snapshots of your data as it flows through the pipeline—each snapshot shows you exactly what your data looks like after one more transformation has been applied. + +### How Step-Through Works + +The `step_through()` method fundamentally changes how the pipeline executes. Instead of running all steps in sequence and only returning the final result, it transforms the pipeline into an iterator that yields intermediate results. + +Here's what happens internally: the method starts by converting your input data into the pipeline's internal transition format, then yields this initial state. Next, it applies the first processing step and yields the result. Then it applies the second step to that result and yields again, and so on. Each `yield` gives you a complete snapshot of the transition at that point. + +This generator pattern is powerful because it's lazy—the pipeline only computes the next step when you ask for it. This means you can stop at any point, inspect the current state thoroughly, and decide whether to continue. You're not forced to run the entire pipeline just to debug one problematic step. + +Instead of running the entire pipeline and only seeing the final result, `step_through()` pauses after each step and gives you the intermediate transition: + +```python +# This creates a generator that yields intermediate states +for i, intermediate_result in enumerate(processor.step_through(input_data)): + print(f"=== After step {i} ===") + + # Inspect the observation at this stage + obs = intermediate_result.get(TransitionKey.OBSERVATION) + if obs: + for key, value in obs.items(): + if isinstance(value, torch.Tensor): + print(f"{key}: shape={value.shape}, dtype={value.dtype}") +``` + +### Interactive Debugging with Breakpoints + +You can add breakpoints in the step-through loop to interactively debug: + +```python +# Step through the pipeline with debugging +for i, intermediate in enumerate(processor.step_through(data)): + print(f"Step {i}: {processor.steps[i].__class__.__name__}") + + # Set a breakpoint to inspect the current state + breakpoint() # Debugger will pause here + + # You can now inspect 'intermediate' in the debugger: + # - Check tensor shapes and values + # - Verify expected transformations + # - Look for unexpected changes +``` + +During the debugger session, you can: + +- Examine `intermediate[TransitionKey.OBSERVATION]` to see observation data +- Check `intermediate[TransitionKey.ACTION]` for action transformations +- Inspect any part of the transition to understand what each step does + +Step-through debugging is perfect for understanding the _data_ transformations, but what about the _structure_ of that data? While hooks and step-through help you debug runtime behavior, you also need to ensure your pipeline produces data in the format expected by downstream components. This is where feature contract validation comes in. + +## Validating Feature Contracts + +Feature contracts define what data structure your pipeline expects as input and produces as output. +Validating these contracts helps catch mismatches early. + +### Understanding Feature Contracts + +Each processor step has a `transform_features()` method that describes how it changes the data structure: + +```python +# Get the expected output features from your pipeline +initial_features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)), + "observation.image": PolicyFeature(type=FeatureType.IMAGE, shape=(3, 224, 224)) + }, + PipelineFeatureType.ACTION: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)) + } +} + +# Check what your pipeline will output +output_features = processor.transform_features(initial_features) + +print("Input features:") +for feature_type, features in initial_features.items(): + print(f" {feature_type}:") + for key, feature in features.items(): + print(f" {key}: {feature.type.value}, shape={feature.shape}") + +print("\nOutput features:") +for feature_type, features in output_features.items(): + print(f" {feature_type}:") + for key, feature in features.items(): + print(f" {key}: {feature.type.value}, shape={feature.shape}") +``` + +### Verifying Expected Features + +Check that your pipeline produces the features you expect: + +```python +# Define what features you expect the pipeline to produce +expected_keys = ["observation.state", "observation.image", "action"] + +print("Validating feature contract...") +for expected_key in expected_keys: + found = False + for feature_type, features in output_features.items(): + if expected_key in features: + feature = features[expected_key] + print(f"✅ {expected_key}: {feature.type.value}, shape={feature.shape}") + found = True + break + + if not found: + print(f"❌ Missing expected feature: {expected_key}") +``` + +This validation helps ensure your pipeline will work correctly with downstream components that expect specific data structures. + +## Summary + +Now that you understand the three debugging approaches, you can tackle any pipeline issue systematically: + +1. **Hooks** - For runtime monitoring and validation without modifying pipeline code +2. **Step-through** - For inspecting intermediate states and understanding transformations +3. **Feature validation** - For ensuring data structure contracts are met + +**When to use each approach:** + +- Start with **step-through debugging** when you need to understand what your pipeline does or when something unexpected happens +- Add **hooks** for continuous monitoring during development and production to catch issues automatically +- Use **feature validation** before deployment to ensure your pipeline works with downstream components + +These three tools work together to give you the complete observability that complex pipelines naturally lack. With hooks watching for issues, step-through helping you understand behavior, and feature validation ensuring compatibility, you'll be able to debug any pipeline confidently and efficiently. diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index f8a5c69b..f6bac1ff 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. -It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. +It combines three key ingredients: + +1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. + +2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. + +3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines. @@ -56,30 +62,242 @@ pip install -e ".[hilserl]" ### Understanding Configuration -The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as: +The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: ```python +class GymManipulatorConfig: + env: HILSerlRobotEnvConfig # Environment configuration (nested) + dataset: DatasetConfig # Dataset recording/replay configuration (nested) + mode: str | None = None # "record", "replay", or None (for training) + device: str = "cpu" # Compute device + class HILSerlRobotEnvConfig(EnvConfig): robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`) - teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`) - wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py` - fps: int = 10 # Control frequency + teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm + processor: HILSerlProcessorConfig # Processing pipeline configuration (nested) name: str = "real_robot" # Environment name - mode: str = None # "record", "replay", or None (for training) - repo_id: str | None = None # LeRobot dataset repository ID - dataset_root: str | None = None # Local dataset root (optional) - task: str = "" # Task identifier - num_episodes: int = 10 # Number of episodes for recording - episode: int = 0 # episode index for replay - device: str = "cuda" # Compute device - push_to_hub: bool = True # Whether to push the recorded datasets to Hub - pretrained_policy_name_or_path: str | None = None # For policy loading - reward_classifier_pretrained_path: str | None = None # For reward model - number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier + task: str | None = None # Task identifier + fps: int = 10 # Control frequency + +# Nested processor configuration +class HILSerlProcessorConfig: + control_mode: str = "gamepad" # Control mode + observation: ObservationConfig | None = None # Observation processing settings + image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings + gripper: GripperConfig | None = None # Gripper control and penalty settings + reset: ResetConfig | None = None # Environment reset and timing settings + inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings + reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings + max_gripper_pos: float | None = 100.0 # Maximum gripper position + +# Sub-configuration classes +class ObservationConfig: + add_joint_velocity_to_observation: bool = False # Add joint velocities to state + add_current_to_observation: bool = False # Add motor currents to state + add_ee_pose_to_observation: bool = False # Add end-effector pose to state + display_cameras: bool = False # Display camera feeds during execution + +class ImagePreprocessingConfig: + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters + resize_size: tuple[int, int] | None = None # Target image size + +class GripperConfig: + use_gripper: bool = True # Enable gripper control + gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage + gripper_penalty_in_reward: bool = False # Include gripper penalty in reward + +class ResetConfig: + fixed_reset_joint_positions: Any | None = None # Joint positions for reset + reset_time_s: float = 5.0 # Time to wait during reset + control_time_s: float = 20.0 # Maximum episode duration + terminate_on_success: bool = True # Whether to terminate episodes on success detection + +class InverseKinematicsConfig: + urdf_path: str | None = None # Path to robot URDF file + target_frame_name: str | None = None # End-effector frame name + end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds + end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis + +class RewardClassifierConfig: + pretrained_path: str | None = None # Path to pretrained reward classifier + success_threshold: float = 0.5 # Success detection threshold + success_reward: float = 1.0 # Reward value for successful episodes + +# Dataset configuration +class DatasetConfig: + repo_id: str # LeRobot dataset repository ID + task: str # Task identifier + root: str | None = None # Local dataset root directory + num_episodes_to_record: int = 5 # Number of episodes for recording + replay_episode: int | None = None # Episode index for replay + push_to_hub: bool = False # Whether to push datasets to Hub ``` +### Processor Pipeline Architecture + +HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components: + +#### Environment Processor Pipeline + +The environment processor (`env_processor`) handles incoming observations and environment state: + +1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format +2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations +3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations +4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions +5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images +6. **TimeLimitProcessorStep** (optional): Enforces episode time limits +7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage +8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models +9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing +10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU) + +#### Action Processor Pipeline + +The action processor (`action_processor`) handles outgoing actions and human interventions: + +1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging +2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals +3. **InterventionActionProcessorStep**: Handles human interventions and episode termination +4. **Inverse Kinematics Pipeline** (when enabled): + - **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format + - **EEReferenceAndDelta**: Computes end-effector reference and delta movements + - **EEBoundsAndSafety**: Enforces workspace safety bounds + - **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets + - **GripperVelocityToJoint**: Handles gripper control commands + +#### Configuration Examples + +**Basic Observation Processing**: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true, + "add_current_to_observation": false, + "display_cameras": false + } + } + } +} +``` + +**Image Processing**: + +```json +{ + "env": { + "processor": { + "image_preprocessing": { + "crop_params_dict": { + "observation.images.front": [180, 250, 120, 150], + "observation.images.side": [180, 207, 180, 200] + }, + "resize_size": [128, 128] + } + } + } +} +``` + +**Inverse Kinematics Setup**: + +```json +{ + "env": { + "processor": { + "inverse_kinematics": { + "urdf_path": "path/to/robot.urdf", + "target_frame_name": "end_effector", + "end_effector_bounds": { + "min": [0.16, -0.08, 0.03], + "max": [0.24, 0.2, 0.1] + }, + "end_effector_step_sizes": { + "x": 0.02, + "y": 0.02, + "z": 0.02 + } + } + } + } +} +``` + +### Advanced Observation Processing + +The HIL-SERL framework supports additional observation processing features that can improve policy learning: + +#### Joint Velocity Processing + +Enable joint velocity estimation to provide the policy with motion information: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true + } + } + } +} +``` + +This processor: + +- Estimates joint velocities using finite differences between consecutive joint position readings +- Adds velocity information to the observation state vector +- Useful for policies that need motion awareness for dynamic tasks + +#### Motor Current Processing + +Monitor motor currents to detect contact forces and load conditions: + +```json +{ + "env": { + "processor": { + "observation": { + "add_current_to_observation": true + } + } + } +} +``` + +This processor: + +- Reads motor current values from the robot's control system +- Adds current measurements to the observation state vector +- Helps detect contact events, object weights, and mechanical resistance +- Useful for contact-rich manipulation tasks + +#### Combined Observation Processing + +You can enable multiple observation processing features simultaneously: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true, + "add_current_to_observation": true, + "add_ee_pose_to_observation": false, + "display_cameras": false + } + } + } +} +``` + +**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data. + ### Finding Robot Workspace Bounds Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot. @@ -128,24 +346,58 @@ With the bounds defined, you can safely collect demonstrations for training. Tra **Setting Up Record Mode** -Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)): +Create a configuration file for recording demonstrations (or edit an existing one like [env_config.json](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/env_config.json)): -1. Set `mode` to `"record"` -2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name") -3. Set `num_episodes` to the number of demonstrations you want to collect -4. Set `crop_params_dict` to `null` initially (we'll determine crops later) -5. Configure `robot`, `cameras`, and other hardware settings +1. Set `mode` to `"record"` at the root level +2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name") +3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect +4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later) +5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section Example configuration section: ```json -"mode": "record", -"repo_id": "username/pick_lift_cube", -"dataset_root": null, -"task": "pick_and_lift", -"num_episodes": 15, -"episode": 0, -"push_to_hub": true +{ + "env": { + "type": "gym_manipulator", + "name": "real_robot", + "fps": 10, + "processor": { + "control_mode": "gamepad", + "observation": { + "display_cameras": false + }, + "image_preprocessing": { + "crop_params_dict": {}, + "resize_size": [128, 128] + }, + "gripper": { + "use_gripper": true, + "gripper_penalty": 0.0 + }, + "reset": { + "reset_time_s": 5.0, + "control_time_s": 20.0 + } + }, + "robot": { + // ... robot configuration ... + }, + "teleop": { + // ... teleoperator configuration ... + } + }, + "dataset": { + "repo_id": "username/pick_lift_cube", + "root": null, + "task": "pick_and_lift", + "num_episodes_to_record": 15, + "replay_episode": 0, + "push_to_hub": true + }, + "mode": "record", + "device": "cpu" +} ``` ### Using a Teleoperation Device @@ -191,10 +443,20 @@ The gamepad provides a very convenient way to control the robot and the episode To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file. ```json +{ + "env": { "teleop": { - "type": "gamepad", - "use_gripper": true + "type": "gamepad", + "use_gripper": true }, + "processor": { + "control_mode": "gamepad", + "gripper": { + "use_gripper": true + } + } + } +} ```

@@ -216,11 +478,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file. ```json +{ + "env": { "teleop": { - "type": "so101_leader", - "port": "/dev/tty.usbmodem585A0077921", # check your port number - "use_degrees": true + "type": "so101_leader", + "port": "/dev/tty.usbmodem585A0077921", + "use_degrees": true }, + "processor": { + "control_mode": "leader", + "gripper": { + "use_gripper": true + } + } + } +} ``` In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure. @@ -251,7 +523,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e During recording: -1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` +1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions` 2. Complete the task successfully 3. The episode ends with a reward of 1 when you press the "success" button 4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0 @@ -310,11 +582,19 @@ observation.images.front: [180, 250, 120, 150] Add these crop parameters to your training configuration: ```json -"crop_params_dict": { - "observation.images.side": [180, 207, 180, 200], - "observation.images.front": [180, 250, 120, 150] -}, -"resize_size": [128, 128] +{ + "env": { + "processor": { + "image_preprocessing": { + "crop_params_dict": { + "observation.images.side": [180, 207, 180, 200], + "observation.images.front": [180, 250, 120, 150] + }, + "resize_size": [128, 128] + } + } + } +} ``` **Recommended image resolution** @@ -343,26 +623,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r **Key Parameters for Data Collection** -- **mode**: set it to `"record"` to collect a dataset -- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub -- **num_episodes**: Number of episodes to record -- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected -- **fps**: Number of frames per second to record -- **push_to_hub**: Whether to push the dataset to the hub +- **mode**: set it to `"record"` to collect a dataset (at root level) +- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub +- **dataset.num_episodes_to_record**: Number of episodes to record +- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`) +- **env.fps**: Number of frames per second to record +- **dataset.push_to_hub**: Whether to push the dataset to the hub -The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier. +The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection. + +**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully. Example configuration section for data collection: ```json { + "env": { + "type": "gym_manipulator", + "name": "real_robot", + "fps": 10, + "processor": { + "reset": { + "reset_time_s": 5.0, + "control_time_s": 20.0, + "terminate_on_success": false + }, + "gripper": { + "use_gripper": true + } + }, + "robot": { + // ... robot configuration ... + }, + "teleop": { + // ... teleoperator configuration ... + } + }, + "dataset": { + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "task": "reward_classifier_task", + "num_episodes_to_record": 20, + "replay_episode": null, + "push_to_hub": true + }, "mode": "record", - "repo_id": "hf_username/dataset_name", - "dataset_root": "data/your_dataset", - "num_episodes": 20, - "push_to_hub": true, - "fps": 10, - "number_of_steps_after_success": 15 + "device": "cpu" } ``` @@ -421,9 +727,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to ```python -env_config = HILSerlRobotEnvConfig( - reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", - # Other environment parameters +config = GymManipulatorConfig( + env=HILSerlRobotEnvConfig( + processor=HILSerlProcessorConfig( + reward_classifier=RewardClassifierConfig( + pretrained_path="path_to_your_pretrained_trained_model" + ) + ), + # Other environment parameters + ), + dataset=DatasetConfig(...), + mode=None # For training ) ``` @@ -432,7 +746,18 @@ or set the argument in the json config file. ```json { - "reward_classifier_pretrained_path": "path_to_your_pretrained_model" + "env": { + "processor": { + "reward_classifier": { + "pretrained_path": "path_to_your_pretrained_model", + "success_threshold": 0.7, + "success_reward": 1.0 + }, + "reset": { + "terminate_on_success": true + } + } + } } ``` @@ -447,7 +772,7 @@ The reward classifier will automatically provide rewards based on the visual inp **Example Workflow for training the reward classifier** 1. **Create the configuration files**: - Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). + Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/reward_classifier/config.json). 2. **Collect a dataset**: @@ -472,7 +797,7 @@ The LeRobot system uses a distributed actor-learner architecture for training. T **Configuration Setup** -Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. +Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. 1. Configure the policy settings (`type="sac"`, `device`, etc.) 2. Set `dataset` to your cropped dataset diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx index c739be83..77191fde 100644 --- a/docs/source/hilserl_sim.mdx +++ b/docs/source/hilserl_sim.mdx @@ -26,15 +26,18 @@ pip install -e ".[hilserl]" ## Configuration -To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: +To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/env_config.json). Key configuration sections include: ### Environment Type and Task ```json { - "type": "hil", - "name": "franka_sim", - "task": "PandaPickCubeGamepad-v0", + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, "device": "cuda" } ``` @@ -45,28 +48,40 @@ Available tasks: - `PandaPickCubeGamepad-v0`: With gamepad control - `PandaPickCubeKeyboard-v0`: With keyboard control -### Gym Wrappers Configuration +### Processor Configuration ```json -"wrapper": { - "gripper_penalty": -0.02, - "control_time_s": 15.0, - "use_gripper": true, - "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785], - "end_effector_step_sizes": { - "x": 0.025, - "y": 0.025, - "z": 0.025 - }, - "control_mode": "gamepad" +{ + "env": { + "processor": { + "control_mode": "gamepad", + "gripper": { + "use_gripper": true, + "gripper_penalty": -0.02 + }, + "reset": { + "control_time_s": 15.0, + "fixed_reset_joint_positions": [ + 0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785 + ] + }, + "inverse_kinematics": { + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + } + } } + } +} ``` Important parameters: -- `gripper_penalty`: Penalty for excessive gripper movement -- `use_gripper`: Whether to enable gripper control -- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector +- `gripper.gripper_penalty`: Penalty for excessive gripper movement +- `gripper.use_gripper`: Whether to enable gripper control +- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector - `control_mode`: Set to `"gamepad"` to use a gamepad controller ## Running with HIL RL of LeRobot @@ -75,39 +90,50 @@ Important parameters: To run the environment, set mode to null: - -```python +```bash python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` - ### Recording a Dataset To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: - -```python +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0" + }, + "dataset": { + "repo_id": "username/sim_dataset", + "root": null, + "task": "pick_cube", + "num_episodes_to_record": 10, + "replay_episode": null, + "push_to_hub": true + }, + "mode": "record" +} +``` + +```bash python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` - ### Training a Policy -To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: +To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/train_config.json) and run the actor and learner servers: - -```python +```bash python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json ``` - In a different terminal, run the learner server: - -```python +```bash python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json ``` - The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 905046be..19b62167 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun from lerobot.record import record_loop +from lerobot.policies.factory import make_processor NUM_EPISODES = 5 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot configuration camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} @@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig( robot = SO100Follower(robot_config) # Initialize the policy -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/eval_", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -559,6 +562,12 @@ _init_rerun(session_name="recording") # Connect the robot robot.connect() +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") @@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES): events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 3dd80dc4..6a615620 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -22,13 +22,38 @@ pip install -e ".[hilserl]" ## Teleoperate and Record a Dataset -To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json). +To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json). -To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record". +To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection: -If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS). +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, + "dataset": { + "repo_id": "your_username/il_gym", + "root": null, + "task": "pick_cube", + "num_episodes_to_record": 30, + "replay_episode": null, + "push_to_hub": true + }, + "mode": "record", + "device": "cuda" +} +``` -By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`. +Key configuration points: + +- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"` +- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes +- Ensure `mode` is set to `"record"` +- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"` +- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"` Then we can run this command to start: @@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \ ## Evaluate your policy in Sim -To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json). +To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json). -Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model` +Here's an example evaluation configuration: + +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, + "dataset": { + "repo_id": "your_username/il_sim_dataset", + "dataset_root": null, + "task": "pick_cube" + }, + "pretrained_policy_name_or_path": "your_username/il_sim_model", + "device": "cuda" +} +``` + +Make sure to replace: + +- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`) +- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`) Then you can run this command to visualize your trained policy diff --git a/docs/source/implement_your_own_processor.mdx b/docs/source/implement_your_own_processor.mdx new file mode 100644 index 00000000..5b7d4f78 --- /dev/null +++ b/docs/source/implement_your_own_processor.mdx @@ -0,0 +1,273 @@ +# Implement your own Robot Processor + +In this tutorial, you'll learn how to implement your own Robot Processor. +It begins by exploring the need for a custom processor, then uses the `NormalizerProcessorStep` as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot. + +## Why would you need a custom processor? + +In most cases, when reading raw data from sensors or when models output actions, you need to process this data to make it compatible with your target system. For example, a common need is normalizing data ranges to make them suitable for neural networks. + +LeRobot's `NormalizerProcessorStep` handles this crucial task: + +```python +# Input: raw joint positions in [0, 180] degrees +raw_action = torch.tensor([90.0, 45.0, 135.0]) + +# After processing: normalized to [-1, 1] range for model training +normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=dataset_stats) +normalized_result = normalizer(transition) +# ... +``` + +Other common processing needs include: + +- **Device placement**: Moving tensors between CPU/GPU and converting data types +- **Format conversion**: Transforming between different data structures +- **Batching**: Adding/removing batch dimensions for model compatibility +- **Safety constraints**: Applying limits to robot commands + +```python +# Example pipeline combining multiple processors +pipeline = PolicyProcessorPipeline([ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + NormalizerProcessorStep(features=features, stats=stats), + DeviceProcessorStep(device="cuda"), + # ... +]) +``` + +LeRobot provides a pipeline mechanism to implement sequences of processing steps for both input data and output actions, making it easy to compose these transformations in the right order for optimal performance. + +## How to implement your own processor? + +We'll use the `NormalizerProcessorStep` as our main example because it demonstrates essential processor patterns including state management, configuration serialization, and tensor handling that you'll commonly need. + +Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods: + +- `__call__`: implements the processing step for the input transition. +- `get_config`: gets the configuration of the processor step. +- `state_dict`: gets the state of the processor step. +- `load_state_dict`: loads the state of the processor step. +- `reset`: resets the state of the processor step. +- `feature_contract`: displays the modification to the feature space during the processor step. + +### Implement the `__call__` method + +The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessorStep` works: + +```python +@dataclass +@ProcessorStepRegistry.register("normalizer_processor") +class NormalizerProcessorStep(ProcessorStep): + """Normalize observations/actions using dataset statistics.""" + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + eps: float = 1e-8 + _tensor_stats: dict = field(default_factory=dict, init=False, repr=False) + + def __post_init__(self): + """Convert stats to tensors for efficient computation.""" + self.stats = self.stats or {} + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=torch.float32) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = transition.copy() + # Normalize observations + # ... + # Normalize action + # ... + return new_transition + +``` + +See the full implementation in `src/lerobot/processor/normalize_processor.py` for complete details. + +**Key principles:** + +- **Always use `transition.copy()`** to avoid side effects +- **Handle both observations and actions** consistently +- **Separate config from state**: `get_config()` returns JSON-serializable params, `state_dict()` returns tensors +- **Convert stats to tensors** in `__post_init__()` for efficient computation + +### Configuration and State Management + +Processors support serialization through three methods that separate configuration from tensor state. The `NormalizerProcessorStep` demonstrates this perfectly - it carries dataset statistics (tensors) in its state, and hyperparameters in its config: + +```python +# Continuing the NormalizerProcessorStep example... + +def get_config(self) -> dict[str, Any]: + """JSON-serializable configuration (no tensors).""" + return { + "eps": self.eps, + "features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()}, + "norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()}, + # ... + } + +def state_dict(self) -> dict[str, torch.Tensor]: + """Tensor state only (e.g., dataset statistics).""" + flat: dict[str, torch.Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat + +def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Restore tensor state at runtime.""" + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to processor's configured device + self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device + ) + # ... +``` + +**Usage:** + +```python +# Save (e.g., inside a policy) +config = normalizer.get_config() +tensors = normalizer.state_dict() + +# Restore (e.g., loading a pretrained policy) +new_normalizer = NormalizerProcessorStep(**config) +new_normalizer.load_state_dict(tensors) +# Now new_normalizer has the same stats and configuration +``` + +### Transform features + +The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging. + +For `NormalizerProcessorStep`, features are typically preserved unchanged since normalization doesn't alter keys or shapes: + +```python +def transform_features(self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Normalization preserves all feature definitions.""" + return features # No changes to feature structure + # ... +``` + +When your processor renames or reshapes data, implement this method to reflect the mapping for downstream components. For example, a simple rename processor: + +```python +def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # Simple renaming + if "pixels" in features: + features["observation.image"] = features.pop("pixels") + + # Pattern-based renaming + for key in list(features.keys()): + if key.startswith("env_state."): + suffix = key[len("env_state."):] + features[f"observation.{suffix}"] = features.pop(key) + # ... + + return features +``` + +**Key principles:** + +- Use `features.pop(old_key)` to remove and get the old feature +- Use `features[new_key] = old_feature` to add the renamed feature +- Always return the modified features dictionary +- Document transformations clearly in the docstring + +### Using overrides + +You can override step parameters at load-time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `DataProcessorPipeline.from_pretrained(...)`. + +**Foundational model adaptation**: This is particularly useful when working with foundational pretrained policies where you rarely have access to the original training statistics. You can inject your own dataset statistics to adapt the normalizer to your specific robot or environment data. + +Example: during policy evaluation on the robot, override the device and rename map. +Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset. + +Direct usage with `from_pretrained`: + +```python +from lerobot.processor import RobotProcessorPipeline + +# Load a foundational policy trained on diverse robot data +# but adapt normalization to your specific robot/environment +new_stats = LeRobotDataset(repo_id="username/my-dataset").meta.stats +processor = RobotProcessorPipeline.from_pretrained( + "huggingface/foundational-robot-policy", # Pretrained foundation model + overrides={ + "normalizer_processor": {"stats": new_stats}, # Inject your robot's statistics + "device_processor": {"device": "cuda:0"}, # registry name for registered steps + "rename_processor": {"rename_map": robot_key_map}, # Map your robot's observation keys + # ... + }, +) +``` + +## Best Practices + +Based on analysis of all LeRobot processor implementations, here are the key patterns and practices: + +### 1. **Safe Data Handling** + +Always create copies of input data to avoid unintended side effects. Use `transition.copy()` and `observation.copy()` rather than modifying data in-place. This prevents your processor from accidentally affecting other components in the pipeline. + +Check for required data before processing and handle missing data gracefully. If your processor expects certain keys (like `"pixels"` for image processing), validate their presence first. For optional data, use safe access patterns like `transition.get()` and handle `None` values appropriately. + +When data validation fails, provide clear, actionable error messages that help users understand what went wrong and how to fix it. + +### 2. **Choose Appropriate Base Classes** + +LeRobot provides specialized base classes that reduce boilerplate code and ensure consistency. Use `ObservationProcessorStep` when you only need to modify observations, `ActionProcessorStep` for action-only processing, and `RobotActionProcessorStep` specifically for dictionary-based robot actions. + +Only inherit directly from `ProcessorStep` when you need full control over the entire transition or when processing multiple transition components simultaneously. The specialized base classes handle the transition management for you and provide type safety. + +### 3. **Registration and Naming** + +Register your processors with descriptive, namespaced names using `@ProcessorStepRegistry.register()`. Use organization prefixes like `"robotics_lab/safety_clipper"` or `"acme_corp/vision_enhancer"` to avoid naming conflicts. Avoid generic names like `"processor"` or `"step"` that could clash with other implementations. + +Good registration makes your processors discoverable and enables clean serialization/deserialization when saving and loading pipelines. + +### 4. **State Management Patterns** + +Distinguish between configuration parameters (JSON-serializable values) and internal state (tensors, buffers). Use dataclass fields with `init=False, repr=False` for internal state that shouldn't appear in the constructor or string representation. + +Implement the `reset()` method to clear internal state between episodes. This is crucial for stateful processors that accumulate data over time, like moving averages or temporal filters. + +Remember that `get_config()` should only return JSON-serializable configuration, while `state_dict()` handles tensor state separately. + +### 5. **Input Validation and Error Handling** + +Validate input types and shapes before processing. Check tensor properties like `dtype` and dimensions to ensure compatibility with your algorithms. For robot actions, verify that required pose components or joint values are present and within expected ranges. + +Use early returns for edge cases where no processing is needed. Provide clear, descriptive error messages that include the expected vs. actual data types or shapes. This makes debugging much easier for users. + +### 6. **Device and Dtype Awareness** + +Design your processors to automatically adapt to the device and dtype of input tensors. Internal tensors (like normalization statistics) should match the input tensor's device and dtype to ensure compatibility with multi-GPU training, mixed precision, and distributed setups. + +Implement a `to()` method that moves your processor's internal state to the specified device. Check device/dtype compatibility at runtime and automatically migrate internal state when needed. This pattern enables seamless operation across different hardware configurations without manual intervention. + +## Conclusion + +You now have all the tools to implement custom processors in LeRobot! The key steps are: + +1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `transform_features`) +2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability +3. **Integrate it** into a `DataProcessorPipeline` with other processing steps +4. **Use base classes** like `ObservationProcessorStep` when possible to reduce boilerplate +5. **Implement device/dtype awareness** to support multi-GPU and mixed precision setups + +The processor system is designed to be modular and composable, allowing you to build complex data processing pipelines from simple, focused components. Whether you're preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors give you the flexibility to handle any data transformation your robotics application requires. + +Key principles for robust processors: + +- **Device/dtype adaptation**: Internal tensors should match input tensors +- **Clear error messages**: Help users understand what went wrong +- **Base class usage**: Leverage specialized base classes to reduce boilerplate +- **Feature contracts**: Declare data structure changes with `transform_features()` + +Start simple, test thoroughly, and ensure your processors work seamlessly across different hardware configurations! diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx new file mode 100644 index 00000000..308edbb3 --- /dev/null +++ b/docs/source/introduction_processors.mdx @@ -0,0 +1,314 @@ +# Introduction to Processors + +In robotics, there's a fundamental mismatch between the data that robots and humans produce and what machine learning models expect. +Robots output raw sensor data like camera images and joint positions that need normalization, batching, and device placement before models can process them. +Language instructions from humans must be tokenized into numerical representations, and different robots use different coordinate systems that need standardization. + +The challenge extends to model outputs as well. +Models might output end-effector positions while robots need joint-space commands, or teleoperators produce relative movements while robots expect absolute commands. +Model predictions are often normalized and need conversion back to real-world scales. + +Cross-domain translation adds another layer of complexity. +Training data from one robot setup needs adaptation for deployment on different hardware, models trained with specific camera configurations must work with new arrangements, and datasets with different naming conventions need harmonization. + +**That's where processors come in.** They serve as universal translators that bridge these gaps, ensuring seamless data flow from sensors to models to actuators. +Processors handle all the preprocessing and postprocessing steps needed to convert raw environment data into model-ready inputs and vice versa. + +This means that your favorite policy can be used like this: + +```python +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.factory import make_pre_post_processors +from lerobot.policies.your_policy import YourPolicy +from lerobot.processor.pipeline import RobotProcessorPipeline, PolicyProcessorPipeline +dataset = LeRobotDataset("hf_user/dataset", episodes=[0]) +sample = dataset[10] + +model = YourPolicy.from_pretrained( + "hf_user/model", +) +model.eval() +model.to("cuda") +preprocessor, postprocessor = make_pre_post_processors(model.config, pretrained_path="hf_user/model", dataset_stats=dataset.meta.stats) + +preprocessed_sample = preprocessor(sample) +action = model.select_action(preprocessed_sample) +postprocessed_action = postprocessor(action) +``` + +## What are Processors? + +In robotics, data comes in many forms: images from cameras, joint positions from sensors, text instructions from users, and more. Each type of data requires specific transformations before a model can use it effectively. Models need this data to be: + +- **Normalized**: Scaled to appropriate ranges for neural network processing +- **Batched**: Organized with proper dimensions for batch processing +- **Tokenized**: Text converted to numerical representations +- **Device-placed**: Moved to the right hardware (CPU/GPU) +- **Type-converted**: Cast to appropriate data types + +Processors handle these transformations through composable, reusable steps that can be chained together into pipelines. Think of them as a modular assembly line where each station performs a specific transformation on your data. + +## Core Concepts + +### EnvTransition: The Universal Data Container + +The `EnvTransition` is the fundamental data structure that flows through all processors. +It's a typed dictionary that represents a complete robot-environment interaction: + +- **OBSERVATION**: All sensor data (images, states, proprioception) +- **ACTION**: The action to execute or that was executed +- **REWARD**: Reinforcement learning signal +- **DONE/TRUNCATED**: Episode boundary indicators +- **INFO**: Arbitrary metadata +- **COMPLEMENTARY_DATA**: Task descriptions, indices, padding flags, inter-step data + +### ProcessorStep: The Building Block + +A `ProcessorStep` is a single transformation unit that processes transitions. It's an abstract base class with two required methods: + +```python +from lerobot.processor import ProcessorStep, EnvTransition + +class MyProcessorStep(ProcessorStep): + """Example processor step - inherit and implement abstract methods.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Transform the transition - REQUIRED abstract method.""" + # Your processing logic here + return transition + + def transform_features(self, features): + """Declare how this step transforms feature shapes/types - REQUIRED abstract method.""" + return features # Most processors return features unchanged +``` + +`__call__` is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. + +`transform_features` is used to declare how this step transforms feature shapes/types. + +### DataProcessorPipeline: The Generic Orchestrator + +The `DataProcessorPipeline[TInput, TOutput]` chains multiple `ProcessorStep` instances together: + +```python +from lerobot.processor import RobotProcessorPipeline, PolicyProcessorPipeline + +# For robot hardware (unbatched data) +robot_processor = RobotProcessorPipeline[RobotAction, RobotAction]( + steps=[step1, step2, step3], + name="robot_pipeline" +) + +# For model training/inference (batched data) +policy_processor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[step1, step2, step3], + name="policy_pipeline" +) +``` + +## RobotProcessorPipeline vs PolicyProcessorPipeline + +The key distinction is in the data structures they handle: + +| Aspect | RobotProcessorPipeline | PolicyProcessorPipeline | +| --------------- | -------------------------------------------- | ---------------------------------------- | +| **Input** | `dict[str, Any]` - Individual robot values | `dict[str, Any]` - Batched tensors | +| **Output** | `dict[str, Any]` - Individual robot commands | `torch.Tensor` - Policy predictions | +| **Use Case** | Real-time robot control | Model training/inference | +| **Data Format** | Unbatched, heterogeneous | Batched, homogeneous | +| **Examples** | `{"joint_1": 0.5}` | `{"observation.state": tensor([[0.5]])}` | + +**Use `RobotProcessorPipeline`** for robot hardware interfaces: + +```python +# Robot data structures: dict[str, Any] for observations and actions +robot_obs: dict[str, Any] = { + "joint_1": 0.5, # Individual joint values + "joint_2": -0.3, + "camera_0": image_array # Raw camera data +} + +robot_action: dict[str, Any] = { + "joint_1": 0.2, # Target joint positions + "joint_2": 0.1, + "gripper": 0.8 +} +``` + +**Use `PolicyProcessorPipeline`** for model training and batch processing: + +```python +# Policy data structures: batch dicts and tensors +policy_batch: dict[str, Any] = { + "observation.state": torch.tensor([[0.5, -0.3]]), # Batched states + "observation.images.camera0": torch.tensor(...), # Batched images + "action": torch.tensor([[0.2, 0.1, 0.8]]) # Batched actions +} + +policy_action: torch.Tensor = torch.tensor([[0.2, 0.1, 0.8]]) # Model output tensor +``` + +## Converter Functions + +LeRobot provides converter functions to bridge different data formats in `lerobot.processor.converters`. These functions handle the crucial translations between robot hardware data structures, policy model formats, and the internal `EnvTransition` representation that flows through processor pipelines. + +| Category | Function | Description | +| ------------------------------ | ----------------------------- | ------------------------------- | +| **Robot Hardware Converters** | `robot_action_to_transition` | Robot dict → EnvTransition | +| | `observation_to_transition` | Robot obs → EnvTransition | +| | `transition_to_robot_action` | EnvTransition → Robot dict | +| **Policy/Training Converters** | `batch_to_transition` | Batch dict → EnvTransition | +| | `transition_to_batch` | EnvTransition → Batch dict | +| | `policy_action_to_transition` | Policy tensor → EnvTransition | +| | `transition_to_policy_action` | EnvTransition → Policy tensor | +| **Utilities** | `create_transition` | Build transitions with defaults | +| | `identity_transition` | Pass-through converter | + +The key insight is that **robot hardware converters** work with individual values and dictionaries, while **policy/training converters** work with batched tensors and model outputs. The converter functions automatically handle the structural differences, so your processor steps can focus on the core transformations without worrying about data format compatibility. + +## Processor Examples + +The following examples demonstrate real-world processor configurations for policy training and inference. + +Here is an example processor for policy training and inference: + +```python +# Training data preprocessing (optimized order for GPU performance) +training_preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[ + RenameObservationsProcessorStep(rename_map={}), # Standardize keys + AddBatchDimensionProcessorStep(), # Add batch dims + TokenizerProcessorStep(tokenizer_name="...", ...), # Tokenize language + DeviceProcessorStep(device="cuda"), # Move to GPU first + NormalizerProcessorStep(features=..., stats=...), # Normalize on GPU + ] +) + +# Model output postprocessing +training_postprocessor = PolicyProcessorPipeline[torch.Tensor, torch.Tensor]( + steps=[ + DeviceProcessorStep(device="cpu"), # Move to CPU + UnnormalizerProcessorStep(features=..., stats=...), # Denormalize + ] + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, +) +``` + +### An interaction between a robot and a policy with processors + +The most common real-world scenario combines both pipeline types robot hardware generates observations that need policy processing, and policy outputs need robot-compatible postprocessing: + +```python +# Real deployment: Robot sensors → Model → Robot commands +with torch.no_grad(): + while not done: + raw_obs = robot.get_observation() # dict[str, Any] + + # Add your robot observation to policy observation processor + + policy_input = policy_preprocessor(raw_obs) # Batched dict + + policy_output = policy.select_action(policy_input) # Policy tensor + + policy_action = policy_postprocessor(policy_output) + + # Add your robot action to policy action processor + + robot.send_action(policy_action) +``` + +## Feature Contracts: Shape and Type Transformation + +Processors don't just transform data - they can also **change the data structure itself**. The `transform_features()` method declares these changes, which is crucial for dataset recording and policy creation. + +### Why Feature Contracts Matter + +When building datasets or policies, LeRobot needs to know: + +- **What data fields will exist** after processing +- **What shapes and types** each field will have +- **How to configure models** for the expected data structure + +```python +# Example: A processor that adds velocity to observations +class VelocityProcessor(ObservationProcessorStep): + def observation(self, obs): + new_obs = obs.copy() + if "observation.state" in obs: + # concatenate computed velocity field to the state + new_obs["observation.state"] = self._compute_velocity(obs["observation.state"]) + return new_obs + + def transform_features(self, features): + """Declare the new velocity field we're adding.""" + state_feature = features[PipelineFeatureType.OBSERVATION].get("observation.state") + if state_feature: + double_shape = (state_feature.shape[0] * 2,) if state_feature.shape else (2,) + features[PipelineFeatureType.OBSERVATION]["observation.state"] = PolicyFeature( + type=FeatureType.STATE, shape=double_shape + ) + return features +``` + +### Feature Specification Functions + +`create_initial_features()` and `aggregate_pipeline_dataset_features()` solve a critical dataset creation problem: determining the exact final data structure before any data is processed. +Since processor pipelines can add new features (like velocity fields), change tensor shapes (like cropping images), or rename keys, datasets need to know the complete output specification upfront to allocate proper storage and define schemas. +These functions work together by starting with robot hardware specifications (`create_initial_features()`) then simulating the entire pipeline transformation (`aggregate_pipeline_dataset_features()`) to compute the final feature dictionary that gets passed to `LeRobotDataset.create()`, ensuring perfect alignment between what processors output and what datasets expect to store. + +```python +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features + +# Start with robot's raw features +initial_features = create_initial_features( + observation=robot.observation_features, # {"joint_1.pos": float, "camera_0": (480,640,3)} + action=robot.action_features # {"joint_1.pos": float, "gripper.pos": float} +) + +# Apply processor pipeline to compute final features +final_features = aggregate_pipeline_dataset_features( + pipeline=my_processor_pipeline, + initial_features=initial_features, + use_videos=True +) + +# Use for dataset creation +dataset = LeRobotDataset.create( + repo_id="my_dataset", + features=final_features, # Knows exactly what data to expect + ... +) +``` + +## Common Processor Steps + +LeRobot provides many registered processor steps. Here are the most commonly used core processors: + +### Essential Processors + +- **`normalizer_processor`**: Normalize observations/actions using dataset statistics (mean/std or min/max) +- **`device_processor`**: Move tensors to CPU/GPU with optional dtype conversion +- **`to_batch_processor`**: Add batch dimensions to transitions for model compatibility +- **`rename_observations_processor`**: Rename observation keys using mapping dictionaries +- **`tokenizer_processor`**: Tokenize natural language task descriptions into tokens and attention masks + +### Next Steps + +- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps +- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines +- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns + +## Summary + +Processors solve the data translation problem in robotics by providing: + +- **Modular transformations**: Composable, reusable processing steps +- **Type safety**: Generic pipelines with compile-time checking +- **Performance optimization**: GPU-accelerated operations +- **Robot/Policy distinction**: Separate pipelines for different data structures +- **Comprehensive ecosystem**: 30+ registered processors for common tasks + +The key insight: `RobotProcessorPipeline` handles unbatched robot hardware data, while `PolicyProcessorPipeline` handles batched model data. Choose the right tool for your data structure! diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx new file mode 100644 index 00000000..71d5457f --- /dev/null +++ b/docs/source/phone_teleop.mdx @@ -0,0 +1,192 @@ +# Phone + +Use your phone (iOS or Android) to control your robot. + +**In this guide you'll learn:** + +- How to connect an iOS/Android phone +- How phone pose is mapped to robot end‑effector (EE) targets +- How to tweak safety limits, gripper control, and IK settings + +To use phone to control your robot, install the relevant dependencies with: + +```bash +pip install lerobot[phone] +``` + +## Get started + +### Supported platforms + +- iOS: Uses the HEBI Mobile I/O app (ARKit pose + buttons). Download the app first, open it and the examples will discover it on your network and stream the phone pose and inputs. +- Android: Uses the `teleop` package (WebXR). When you start the Python process, it prints a local URL. Open the link on your phone, tap Start, then use Move to stream pose. + +Links: + +- Android WebXR library: [`teleop` on PyPI](https://pypi.org/project/teleop/) +- iOS app: [HEBI Mobile I/O](https://docs.hebi.us/tools.html#mobile-io) + +### Phone orientation and controls + +- Orientation: hold the phone with the screen facing up and the top edge pointing in the same direction as the robot gripper. This ensures calibration aligns the phone’s frame with the robot frame so motion feels natural, see the image below for reference. +- Enable/disable: + - iOS: Hold `B1` to enable teleoperation, release to stop. The first press captures a reference pose. + - Android: Press and hold the `Move` button, release to stop. The first press captures a reference pose. +- Gripper control: + - iOS: Analog input `A3` controls the gripper as velocity input. + - Android: Buttons `A` and `B` act like increment/decrement (A opens, B closes). You can tune velocity in the `GripperVelocityToJoint` step. + +Phone teleop orientation + +### Step 1: Choose the platform + +Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`. The API is identical across platforms, only the input source differs. All examples are under `examples/` and have `phone_so100_*.py` variants. + +Teleoperation example: + +```36:43:examples/phone_so100_teleop.py +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS + +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID +teleop_device = Phone(teleop_config) +``` + +### Step 2: Connect and calibrate + +When `Phone(teleop_config)` is created and `connect()` is called, calibration is prompted automatically. Hold the phone in the orientation described above, then: + +- iOS: press and hold `B1` to capture the reference pose. +- Android: press `Move` button on the WebXR page to capture the reference pose. + +Why calibrate? We capture the current pose so subsequent poses are expressed in a robot aligned frame. When you again press the button to enable control, the position is recaptured to avoid drift when your phone is repositioned while it was disabled. + +### Step 3: Run an example + +Run on of the examples scripts to teleoperate, record a dataset, replay a dataset or evaluate a policy. + +All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port. + +Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf) + +- Run this example to teleoperate: + + ```bash + python examples/phone_to_so100/teleoperate.py + ``` + +After running the example: + +- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move. +- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper. + +Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide. + +- Run this example to record a dataset, which saves absolute end effector observations and actions: + + ```bash + python examples/phone_to_so100/record.py + ``` + +- Run this example to replay recorded episodes: + + ```bash + python examples/phone_to_so100/replay.py + ``` + +- Run this example to evaluate a pretrained policy: + + ```bash + python examples/phone_to_so100/evaluate.py + ``` + +### Important pipeline steps and options + +- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame. + + ```examples/phone_to_so100/teleoperate.py + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), + ) + + ``` + +- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs. + + ```src/lerobot/teleoperators/phone/phone_processor.py + action["enabled"] = enabled + action["target_x"] = -pos[1] if enabled else 0.0 + action["target_y"] = pos[0] if enabled else 0.0 + action["target_z"] = pos[2] if enabled else 0.0 + action["target_wx"] = rotvec[1] if enabled else 0.0 + action["target_wy"] = rotvec[0] if enabled else 0.0 + action["target_wz"] = -rotvec[2] if enabled else 0.0 + action["gripper_vel"] = gripper_vel # Still send gripper action when disabled + ``` + +- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed. + + ```examples/phone_to_so100/teleoperate.py + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + ``` + +- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits. + + ```examples/phone_to_so100/teleoperate.py + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ) + ``` + +- The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied. + + ```examples/phone_to_so100/teleoperate.py + GripperVelocityToJoint(speed_factor=20.0) + ``` + +#### Different IK initial guesses + +We use different IK initial guesses in the kinematic steps. As initial guess either the current measured joints or the previous IK solution is used. + +- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame. + + ```examples/phone_to_so100/record.py + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, # closed loop + ) + ``` + +- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback. + + ```examples/phone_to_so100/replay.py + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # open loop + ) + ``` + +### Pipeline steps explained + +- MapPhoneActionToRobotAction: converts calibrated phone pose and inputs into target deltas and a gripper command. Motion is gated by an enable signal (B1 on iOS, Move on Android). +- EEReferenceAndDelta: latches a reference EE pose on enable and combines it with target deltas to produce an absolute desired EE pose each frame. When disabled, it keeps sending the last commanded pose. +- EEBoundsAndSafety: clamps the EE pose to a workspace and rate‑limits jumps for safety. Also declares `action.ee.*` features. +- InverseKinematicsEEToJoints: turns an EE pose into joint positions with IK. `initial_guess_current_joints=True` is recommended for closed‑loop control; set `False` for open‑loop replay for stability. +- GripperVelocityToJoint: integrates a velocity‑like gripper input into an absolute gripper position using the current measured state. +- ForwardKinematicsJointsToEE: computes `observation.state.ee.*` from observed joints for logging and training on EE state. + +### Troubleshooting + +- iOS not discovered: ensure HEBI Mobile I/O is open and your laptop/phone are on the same network. +- Android URL not reachable: check local you used `https` instead of `http`, use the exact IP printed by the script and allow your browser to enter and ignore the certificate issue. +- Motion feels inverted: adjust the sign flips in `MapPhoneActionToRobotAction` or swap axes to match your setup. diff --git a/docs/source/processors_robots_teleop.mdx b/docs/source/processors_robots_teleop.mdx new file mode 100644 index 00000000..c4fcbe03 --- /dev/null +++ b/docs/source/processors_robots_teleop.mdx @@ -0,0 +1,151 @@ +# Processors for Robots and Teleoperators + +This guide shows how to build and modify processing pipelines that connect teleoperators (e.g., phone) to robots and datasets. Pipelines standardize conversions between different action/observation spaces so you can swap teleops and robots without rewriting glue code. + +We use the Phone to SO‑100 follower examples for concreteness, but the same patterns apply to other robots. + +**What you'll learn** + +- Absolute vs. relative EE control: What each means, trade‑offs, and how to choose for your task. +- Three-pipeline pattern: How to map teleop actions → dataset actions → robot commands, and robot observations → dataset observations. +- Adapters (`to_transition` / `to_output`): How these convert raw dicts to `EnvTransition` and back to reduce boilerplate. +- Dataset feature contracts: How steps declare features via `transform_features(...)`, and how to aggregate/merge them for recording. +- Choosing a representation: When to store joints, absolute EE poses, or relative EE deltas—and how that affects training. +- Pipeline customization guidance: How to swap robots/URDFs safely and tune bounds, step sizes, and options like IK initialization. + +### Absolute vs relative EE control + +The examples in this guide use absolute end effector (EE) poses because they are easy to reason about. In practice, relative EE deltas or joint position are often preferred as learning features. + +With processors, you choose the learning features you want to use for your policy. This could be joints positions/velocities, absolute EE, or relative EE positions. You can also choose to store other features, such as joint torques, motor currents, etc. + +## Three pipelines + +We often compose three pipelines. Depending on your setup, some can be empty if action and observation spaces already match. +Each of these pipelines handle different conversions between different action and observation spaces. Below is a quick explanation of each pipeline. + +1. Pipeline 1: Teleop action space → dataset action space (phone pose → EE targets) +2. Pipeline 2: Dataset action space → robot command space (EE targets → joints) +3. Pipeline 3: Robot observation space → dataset observation space (joints → EE pose) + +Below is an example of the three pipelines that we use in the phone to SO-100 follower examples: + +```69:90:examples/phone_so100_record.py +phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()), + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint(), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +robot_ee_to_joints_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # dataset action -> robot + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( # robot obs -> dataset obs + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) +``` + +## Why to_transition / to_output + +To convert from robot/teleoperator to pipeline and back, we use the `to_transition` and `to_output` pipeline adapters. +They standardize conversions to reduce boilerplate code, and form the bridge between the robot and teleoperators raw dictionaries and the pipeline’s `EnvTransition` format. +In the phone to SO-100 follower examples we use the following adapters: + +- `robot_action_to_transition`: transforms the teleop action dict to a pipeline transition. +- `transition_to_robot_action`: transforms the pipeline transition to a robot action dict. +- `observation_to_transition`: transforms the robot observation dict to a pipeline transition. +- `transition_to_observation`: transforms the pipeline transition to a observation dict. + +Checkout [src/lerobot/processor/converters.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/processor/converters.py) for more details. + +## Dataset feature contracts + +Dataset features are determined by the keys saved in the dataset. Each step can declare what features it modifies in a contract called `transform_features(...)`. Once you build a processor, the processor can then aggregate all of these features with `aggregate_pipeline_dataset_features()` and merge multiple feature dicts with `combine_feature_dicts(...)`. + +Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples: + +```src/lerobot/robots/so100_follower/robot_kinematic_processor.py + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features +``` + +Here we declare what PolicyFeatures we modify in this step, so we know what features we can expect when we run the processor. These features can then be aggregated and used to create the dataset features. + +Below is an example of how we aggregate and merge features in the phone to SO-100 record example: + +```121:145:examples/phone_so100_record.py +features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose_processor, + initial_features=create_initial_features(action=phone.action_features), # <- Action features we can expect, these come from our teleop device (phone) and action processor + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=create_initial_features(observation=robot.observation_features), # <- Observation features we can expect, these come from our robot and observation processor + use_videos=True, + patterns=["observation.state.ee"], # <- Here you could optionally filter the features we want to store in the dataset, with a specific pattern + + ), + ), +``` + +How it works: + +- `aggregate_pipeline_dataset_features(...)`: applies `transform_features` across the pipeline and filters by patterns (images included when `use_videos=True`, and state features included when `patterns` is specified). +- `combine_feature_dicts(...)`: combine multiple feature dicts. +- Recording with `record_loop(...)` uses `build_dataset_frame(...)` to build frames consistent with `dataset.features` before we call `add_frame(...)` to add the frame to the dataset. + +## Guidance when customizing robot pipelines + +You can store any of the following features as your action/observation space: + +- Joint positions +- Absolute EE poses +- Relative EE deltas +- Other features: joint velocity, torques, etc. + +Pick what you want to use for your policy action and observation space and configure/modify the pipelines and steps accordingly. + +### Different robots + +- You can easily reuse pipelines, for example to use another robot with phone teleop, modify the examples and swap the robot `RobotKinematics` (URDF) and `motor_names` to use your own robot with Phone teleop. Additionally you should ensure `target_frame_name` points to your gripper/wrist. + +### Safety first + +- When changing pipelines, start with tight bounds, implement safety steps when working with real robots. +- Its advised to start with simulation first and then move to real robots. + +Thats it! We hope this guide helps you get started with customizing your robot pipelines, If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support. diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index f2de79db..7f3fad36 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetad from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy +from lerobot.policies.factory import make_pre_post_processors def main(): @@ -56,9 +57,10 @@ def main(): cfg = DiffusionConfig(input_features=input_features, output_features=output_features) # We can now instantiate our policy with this config and the dataset stats. - policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) + policy = DiffusionPolicy(cfg) policy.train() policy.to(device) + preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats) # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames # which can differ for inputs, outputs and rewards (if there are some). @@ -99,7 +101,7 @@ def main(): done = False while not done: for batch in dataloader: - batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = preprocessor(batch) loss, _ = policy.forward(batch) loss.backward() optimizer.step() @@ -114,6 +116,8 @@ def main(): # Save a policy checkpoint. policy.save_pretrained(output_directory) + preprocessor.save_pretrained(output_directory) + postprocessor.save_pretrained(output_directory) if __name__ == "__main__": diff --git a/examples/5_train_with_streaming.py b/examples/5_train_with_streaming.py index 17818410..93d13535 100644 --- a/examples/5_train_with_streaming.py +++ b/examples/5_train_with_streaming.py @@ -30,6 +30,7 @@ from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors def main(): @@ -60,9 +61,10 @@ def main(): # We can now instantiate our policy with this config and the dataset stats. cfg = ACTConfig(input_features=input_features, output_features=output_features) - policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats) + policy = ACTPolicy(cfg) policy.train() policy.to(device) + preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats) # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy. # Here, we use delta-timestamps to only provide ground truth actions for supervision @@ -89,13 +91,7 @@ def main(): done = False while not done: for batch in dataloader: - batch = { - k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v) - for k, v in batch.items() - } - batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - - # batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = preprocessor(batch) loss, _ = policy.forward(batch) loss.backward() optimizer.step() @@ -110,6 +106,8 @@ def main(): # Save a policy checkpoint. policy.save_pretrained(output_directory) + preprocessor.save_pretrained(output_directory) + postprocessor.save_pretrained(output_directory) if __name__ == "__main__": diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 57fb62e1..3dbb10f5 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -1,6 +1,24 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import make_default_processors from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.control_utils import init_keyboard_listener @@ -11,12 +29,16 @@ NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" -# Create the robot and teleoperator configurations +# Create the robot configuration & robot robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + robot = LeKiwiClient(robot_config) -policy = ACTPolicy.from_pretrained("/") +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -25,7 +47,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -33,33 +55,52 @@ dataset = LeRobotDataset.create( image_writer_threads=4, ) +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() -_init_rerun(session_name="recording") +# TODO(Steven): Update this example to use pipelines +teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() +# Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() +_init_rerun(session_name="lekiwi_evaluate") if not robot.is_connected: raise ValueError("Robot is not connected!") +print("Starting evaluate loop...") recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Run the policy inference loop + # Main record loop record_loop( robot=robot, events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) - # Logic for reset env + # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] ): @@ -71,6 +112,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) if events["rerecord_episode"]: @@ -80,11 +124,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: dataset.clear_episode_buffer() continue + # Save episode dataset.save_episode() recorded_episodes += 1 -# Upload to hub and clean up -dataset.push_to_hub() - +# Clean up +log_say("Stop recording") robot.disconnect() listener.stop() +dataset.push_to_hub() diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 11a71676..f5d109d5 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -1,5 +1,22 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.processor import make_default_processors from lerobot.record import record_loop from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient @@ -9,21 +26,26 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun -NUM_EPISODES = 3 +NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 30 RESET_TIME_SEC = 10 TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" # Create the robot and teleoperator configurations robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") keyboard_config = KeyboardTeleopConfig() +# Initialize the robot and teleoperator robot = LeKiwiClient(robot_config) leader_arm = SO100Leader(leader_arm_config) keyboard = KeyboardTeleop(keyboard_config) +# TODO(Steven): Update this example to use pipelines +teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") obs_features = hw_to_dataset_features(robot.observation_features, "observation") @@ -31,7 +53,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_REPO_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -39,23 +61,25 @@ dataset = LeRobotDataset.create( image_writer_threads=4, ) +# Connect the robot and teleoperator # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() leader_arm.connect() keyboard.connect() +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() _init_rerun(session_name="lekiwi_record") -listener, events = init_keyboard_listener() - if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot, leader arm of keyboard is not connected!") + raise ValueError("Robot or teleop is not connected!") +print("Starting record loop...") recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Recording episode {recorded_episodes}") - # Run the record loop + # Main record loop record_loop( robot=robot, events=events, @@ -65,9 +89,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) - # Logic for reset env + # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] ): @@ -80,6 +107,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) if events["rerecord_episode"]: @@ -89,13 +119,14 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: dataset.clear_episode_buffer() continue + # Save episode dataset.save_episode() recorded_episodes += 1 -# Upload to hub and clean up -dataset.push_to_hub() - +# Clean up +log_say("Stop recording") robot.disconnect() leader_arm.disconnect() keyboard.disconnect() listener.stop() +dataset.push_to_hub() diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 248354df..0f8eabdf 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -1,3 +1,19 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -8,25 +24,36 @@ from lerobot.utils.utils import log_say EPISODE_IDX = 0 +# Initialize the robot config robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + +# Initialize the robot robot = LeKiwiClient(robot_config) +# Fetch the dataset to replay dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) -actions = dataset.hf_dataset.select_columns("action") +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") +# Connect to the robot robot.connect() if not robot.is_connected: raise ValueError("Robot is not connected!") +print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") -for idx in range(dataset.num_frames): +for idx in range(len(episode_frames)): t0 = time.perf_counter() + # Get recorded action from dataset action = { name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) } - robot.send_action(action) + + # Send action to robot + _ = robot.send_action(action) busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index 8358a2b9..cde4000d 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -1,3 +1,19 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig @@ -13,35 +29,44 @@ robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi") teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard") +# Initialize the robot and teleoperator robot = LeKiwiClient(robot_config) leader_arm = SO100Leader(teleop_arm_config) keyboard = KeyboardTeleop(keyboard_config) +# Connect to the robot and teleoperator # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() leader_arm.connect() keyboard.connect() +# Init rerun viewer _init_rerun(session_name="lekiwi_teleop") if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot, leader arm of keyboard is not connected!") + raise ValueError("Robot or teleop is not connected!") +print("Starting teleop loop...") while True: t0 = time.perf_counter() + # Get robot observation observation = robot.get_observation() + # Get teleop action + # Arm arm_action = leader_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - + # Keyboard keyboard_keys = keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_keys) - log_rerun_data(observation, {**arm_action, **base_action}) - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action - robot.send_action(action) + # Send action to robot + _ = robot.send_action(action) + + # Visualize + log_rerun_data(observation=observation, action=action) busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py new file mode 100644 index 00000000..e76b1135 --- /dev/null +++ b/examples/phone_to_so100/evaluate.py @@ -0,0 +1,197 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_teleop_action_processor, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Create the robot configuration & robot +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +robot = SO100Follower(robot_config) + +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joints observation to EE observation +robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + # User for now should be explicit on the feature keys that were used for record + # Alternatively, the user can pass the processor step that has the right features + aggregate_pipeline_dataset_features( + pipeline=make_default_teleop_action_processor(), + initial_features=create_initial_features( + action={ + f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] + } + ), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot +robot.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="phone_so100_evaluate") + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting evaluate loop...") +episode_idx = 0 +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py new file mode 100644 index 00000000..768041d6 --- /dev/null +++ b/examples/phone_to_so100/record.py @@ -0,0 +1,204 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEE, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 2 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Create the robot and teleoperator configurations +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +phone = Phone(teleop_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to EE action +phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.20, + max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint(speed_factor=20.0), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joint observation to EE observation +robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose_processor, + initial_features=create_initial_features(action=phone.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Connect the robot and teleoperator +robot.connect() +phone.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="phone_so100_record") + +if not robot.is_connected or not phone.is_connected: + raise ValueError("Robot or teleop is not connected!") + + +print("Starting record loop. Move your phone to teleoperate the robot...") +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +phone.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py new file mode 100644 index 00000000..80c65a4c --- /dev/null +++ b/examples/phone_to_so100/replay.py @@ -0,0 +1,99 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.utils import log_say + +EPISODE_IDX = 0 +HF_REPO_ID = "/" + +# Initialize the robot config +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Fetch the dataset to replay +dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") + +# Connect to the robot +robot.connect() + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting replay loop...") +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(len(episode_frames)): + t0 = time.perf_counter() + + # Get recorded action from dataset + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + # Get robot observation + robot_obs = robot.get_observation() + + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +# Clean up +robot.disconnect() diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py new file mode 100644 index 00000000..eb5ed352 --- /dev/null +++ b/examples/phone_to_so100/teleoperate.py @@ -0,0 +1,114 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specif + +import time + +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data + +FPS = 30 + +# Initialize the robot and teleoperator +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +teleop_device = Phone(teleop_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to ee pose action to joint action +phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint( + speed_factor=20.0, + ), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Connect to the robot and teleoperator +robot.connect() +teleop_device.connect() + +# Init rerun viewer +_init_rerun(session_name="phone_so100_teleop") + +if not robot.is_connected or not teleop_device.is_connected: + raise ValueError("Robot or teleop is not connected!") + +print("Starting teleop loop. Move your phone to teleoperate the robot...") +while True: + t0 = time.perf_counter() + + # Get robot observation + robot_obs = robot.get_observation() + + # Get teleop action + phone_obs = teleop_device.get_action() + + # Phone -> EE pose -> Joints transition + joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + # Visualize + log_rerun_data(observation=phone_obs, action=joint_action) + + busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py new file mode 100644 index 00000000..fd10bf86 --- /dev/null +++ b/examples/so100_to_so100_EE/evaluate.py @@ -0,0 +1,198 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_teleop_action_processor, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Create the robot configuration & robot +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +robot = SO100Follower(robot_config) + +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joints observation to EE observation +robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + # User for now should be explicit on the feature keys that were used for record + # Alternatively, the user can pass the processor step that has the right features + aggregate_pipeline_dataset_features( + pipeline=make_default_teleop_action_processor(), + initial_features=create_initial_features( + action={ + f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] + } + ), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot and teleoperator +robot.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="so100_so100_evaluate") + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting evaluate loop...") +episode_idx = 0 +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py new file mode 100644 index 00000000..abb8fb99 --- /dev/null +++ b/examples/so100_to_so100_EE/record.py @@ -0,0 +1,203 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 2 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Create the robot and teleoperator configurations +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +follower_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True +) +leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm") + +# Initialize the robot and teleoperator +follower = SO100Follower(follower_config) +leader = SO100Leader(leader_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +follower_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(follower.bus.motors.keys()), +) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +leader_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(leader.bus.motors.keys()), +) + +# Build pipeline to convert follower joints to EE observation +follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys()) + ), + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Build pipeline to convert leader joints to EE action +leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys()) + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert EE action to follower joints +ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + [ + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + InverseKinematicsEEToJoints( + kinematics=follower_kinematics_solver, + motor_names=list(follower.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=leader_joints_to_ee, + initial_features=create_initial_features(action=leader.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=follower_joints_to_ee, + initial_features=create_initial_features(observation=follower.observation_features), + use_videos=True, + ), + ), + robot_type=follower.name, + use_videos=True, + image_writer_threads=4, +) + + +# Connect the robot and teleoperator +leader.connect() +follower.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +if not leader.is_connected or not follower.is_connected: + raise ValueError("Robot or teleop is not connected!") + +print("Starting record loop...") +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +leader.disconnect() +follower.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py new file mode 100644 index 00000000..6987f483 --- /dev/null +++ b/examples/so100_to_so100_EE/replay.py @@ -0,0 +1,100 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.utils import log_say + +EPISODE_IDX = 0 +HF_REPO_ID = "/" + +# Initialize the robot config +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Fetch the dataset to replay +dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") + +# Connect to the robot +robot.connect() + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting replay loop...") +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(len(episode_frames)): + t0 = time.perf_counter() + + # Get recorded action from dataset + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + # Get robot observation + robot_obs = robot.get_observation() + + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +# Clean up +robot.disconnect() diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py new file mode 100644 index 00000000..ab54e723 --- /dev/null +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -0,0 +1,122 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + robot_action_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data + +FPS = 30 + +# Initialize the robot and teleoperator config +follower_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) +leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm") + +# Initialize the robot and teleoperator +follower = SO100Follower(follower_config) +leader = SO100Leader(leader_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +follower_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(follower.bus.motors.keys()), +) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +leader_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(leader.bus.motors.keys()), +) + +# Build pipeline to convert teleop joints to EE action +leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys()) + ), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +# build pipeline to convert EE action to robot joints +ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + [ + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + InverseKinematicsEEToJoints( + kinematics=follower_kinematics_solver, + motor_names=list(follower.bus.motors.keys()), + initial_guess_current_joints=False, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Connect to the robot and teleoperator +follower.connect() +leader.connect() + +# Init rerun viewer +_init_rerun(session_name="so100_so100_EE_teleop") + +print("Starting teleop loop...") +while True: + t0 = time.perf_counter() + + # Get robot observation + robot_obs = follower.get_observation() + + # Get teleop observation + leader_joints_obs = leader.get_action() + + # teleop joints -> teleop EE action + leader_ee_act = leader_to_ee(leader_joints_obs) + + # teleop EE -> robot joints + follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs)) + + # Send action to robot + _ = follower.send_action(follower_joints_act) + + # Visualize + log_rerun_data(observation=leader_ee_act, action=follower_joints_act) + + busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/pyproject.toml b/pyproject.toml index 7241a78f..70755cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency +transformers-dep = ["transformers>=4.52.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors @@ -111,6 +111,7 @@ intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] +phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # stretch = [ # "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", # "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", @@ -153,7 +154,8 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]" + "lerobot[xarm]", + "lerobot[phone]", ] [project.scripts] diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index f5fa727c..7532f061 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig @@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): """ n_obs_steps: int = 1 - normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict) input_features: dict[str, PolicyFeature] = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict) diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 6040ff70..e0252784 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -24,6 +24,12 @@ class FeatureType(str, Enum): ENV = "ENV" ACTION = "ACTION" REWARD = "REWARD" + LANGUAGE = "LANGUAGE" + + +class PipelineFeatureType(str, Enum): + ACTION = "ACTION" + OBSERVATION = "OBSERVATION" class NormalizationMode(str, Enum): diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 382435a9..464969c7 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -21,8 +21,14 @@ OBS_ENV_STATE = "observation.environment_state" OBS_STATE = "observation.state" OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" +OBS_LANGUAGE = "observation.language" ACTION = "action" REWARD = "next.reward" +TRUNCATED = "next.truncated" +DONE = "next.done" + +OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" +OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ROBOTS = "robots" ROBOT_TYPE = "robot_type" @@ -39,6 +45,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" +POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor" +POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor" + if "LEROBOT_HOME" in os.environ: raise ValueError( f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py new file mode 100644 index 00000000..b55ccf8a --- /dev/null +++ b/src/lerobot/datasets/pipeline_features.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections.abc import Sequence +from typing import Any + +from lerobot.configs.types import PipelineFeatureType +from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.processor import DataProcessorPipeline + + +def create_initial_features( + action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None +) -> dict[PipelineFeatureType, dict[str, Any]]: + """ + Creates the initial features dict for the dataset from action and observation specs. + + Args: + action: A dictionary of action feature names to their types/shapes. + observation: A dictionary of observation feature names to their types/shapes. + + Returns: + The initial features dictionary structured by PipelineFeatureType. + """ + features = {PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}} + if action: + features[PipelineFeatureType.ACTION] = action + if observation: + features[PipelineFeatureType.OBSERVATION] = observation + return features + + +# Helper to filter state/action keys based on regex patterns. +def should_keep(key: str, patterns: tuple[str]) -> bool: + if patterns is None: + return True + return any(re.search(pat, key) for pat in patterns) + + +def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str: + for prefix in prefixes_to_strip: + if key.startswith(prefix): + return key[len(prefix) :] + return key + + +# Define prefixes to strip from feature keys for clean names. +# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms. +PREFIXES_TO_STRIP = tuple( + f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1]) +) + + +def aggregate_pipeline_dataset_features( + pipeline: DataProcessorPipeline, + initial_features: dict[PipelineFeatureType, dict[str, Any]], + *, + use_videos: bool = True, + patterns: Sequence[str] | None = None, +) -> dict[str, dict]: + """ + Aggregates and filters pipeline features to create a dataset-ready features dictionary. + + This function transforms initial features using the pipeline, categorizes them as action or observations + (image or state), filters them based on `use_videos` and `patterns`, and finally + formats them for use with a Hugging Face LeRobot Dataset. + + Args: + pipeline: The DataProcessorPipeline to apply. + initial_features: A dictionary of raw feature specs for actions and observations. + use_videos: If False, image features are excluded. + patterns: A sequence of regex patterns to filter action and state features. + Image features are not affected by this filter. + + Returns: + A dictionary of features formatted for a Hugging Face LeRobot Dataset. + """ + all_features = pipeline.transform_features(initial_features) + + # Intermediate storage for categorized and filtered features. + processed_features: dict[str, dict[str, Any]] = { + "action": {}, + "observation": {}, + } + images_token = OBS_IMAGES.split(".")[-1] + + # Iterate through all features transformed by the pipeline. + for ptype, feats in all_features.items(): + if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]: + continue + + for key, value in feats.items(): + # 1. Categorize the feature. + is_action = ptype == PipelineFeatureType.ACTION + # Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3. + # All other observations are treated as state. + is_image = not is_action and ( + (isinstance(value, tuple) and len(value) == 3) + or ( + key.startswith(f"{OBS_IMAGES}.") + or key.startswith(f"{images_token}.") + or f".{images_token}." in key + ) + ) + + # 2. Apply filtering rules. + if is_image and not use_videos: + continue + if not is_image and not should_keep(key, patterns): + continue + + # 3. Add the feature to the appropriate group with a clean name. + name = strip_prefix(key, PREFIXES_TO_STRIP) + if is_action: + processed_features["action"][name] = value + else: + processed_features["observation"][name] = value + + # Convert the processed features into the final dataset format. + dataset_features = {} + if processed_features["action"]: + dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) + if processed_features["observation"]: + dataset_features.update( + hw_to_dataset_features(processed_features["observation"], "observation", use_videos) + ) + + return dataset_features diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index c840d5bc..922fc4e3 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -150,14 +150,20 @@ def get_video_size_in_mb(mp4_path: Path) -> float: def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: - """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. + """Flatten a nested dictionary by joining keys with a separator. - For example: - ``` - >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` - >>> print(flatten_dict(dct)) - {"a/b": 1, "a/c/d": 2, "e": 3} - ``` + Example: + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3} + >>> print(flatten_dict(dct)) + {'a/b': 1, 'a/c/d': 2, 'e': 3} + + Args: + d (dict): The dictionary to flatten. + parent_key (str): The base key to prepend to the keys in this level. + sep (str): The separator to use between keys. + + Returns: + dict: A flattened dictionary. """ items = [] for k, v in d.items(): @@ -170,6 +176,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: def unflatten_dict(d: dict, sep: str = "/") -> dict: + """Unflatten a dictionary with delimited keys into a nested dictionary. + + Example: + >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3} + >>> print(unflatten_dict(flat_dct)) + {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + + Args: + d (dict): A dictionary with flattened keys. + sep (str): The separator used in the keys. + + Returns: + dict: A nested dictionary. + """ outdict = {} for key, value in d.items(): parts = key.split(sep) @@ -183,6 +203,19 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: + """Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible. + + Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types. + + Args: + stats (dict): A dictionary that may contain non-serializable numeric types. + + Returns: + dict: A dictionary with all values converted to JSON-serializable types. + + Raises: + NotImplementedError: If a value has an unsupported type. + """ serialized_dict = {} for key, value in flatten_dict(stats).items(): if isinstance(value, (torch.Tensor, np.ndarray)): @@ -199,6 +232,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. + + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ # Embed image bytes into the table before saving to parquet format = dataset.format dataset = dataset.with_format("arrow") @@ -208,11 +252,27 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ with open(fpath) as f: return json.load(f) def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ fpath.parent.mkdir(exist_ok=True, parents=True) with open(fpath, "w") as f: json.dump(data, f, indent=4, ensure_ascii=False) @@ -223,6 +283,16 @@ def write_info(info: dict, local_dir: Path) -> None: def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ info = load_json(local_dir / INFO_PATH) for ft in info["features"].values(): ft["shape"] = tuple(ft["shape"]) @@ -230,16 +300,40 @@ def load_info(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path) -> None: + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. + """ if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) @@ -297,6 +391,18 @@ def backward_compatible_episodes_stats( def load_image_as_numpy( fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True ) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ img = PILImage.open(fpath).convert("RGB") img_array = np.array(img, dtype=dtype) if channel_first: # (H, W, C) -> (C, H, W) @@ -307,10 +413,19 @@ def load_image_as_numpy( def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Get a transform function that convert items from Hugging Face dataset (pyarrow) - to torch tensors. Importantly, images are converted from PIL, which corresponds to - a channel last representation (h w c) of uint8 type, to a torch image representation - with channel first (c h w) of float32 type in range [0,1]. + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. + + Returns: + dict: The batch with items converted to torch tensors. """ for key in items_dict: first_item = items_dict[key][0] @@ -325,6 +440,14 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to def is_valid_version(version: str) -> bool: + """Check if a string is a valid PEP 440 version. + + Args: + version (str): The version string to check. + + Returns: + bool: True if the version string is valid, False otherwise. + """ try: packaging.version.parse(version) return True @@ -338,6 +461,18 @@ def check_version_compatibility( current_version: str | packaging.version.Version, enforce_breaking_major: bool = True, ) -> None: + """Check for version compatibility between a dataset and the current codebase. + + Args: + repo_id (str): The repository ID for logging purposes. + version_to_check (str | packaging.version.Version): The version of the dataset. + current_version (str | packaging.version.Version): The current version of the codebase. + enforce_breaking_major (bool): If True, raise an error on major version mismatch. + + Raises: + BackwardCompatibilityError: If the dataset version is from a newer, incompatible + major version of the codebase. + """ v_check = ( packaging.version.parse(version_to_check) if not isinstance(version_to_check, packaging.version.Version) @@ -355,7 +490,14 @@ def check_version_compatibility( def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: - """Returns available valid versions (branches and tags) on given repo.""" + """Return available valid versions (branches and tags) on a given Hub repo. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + + Returns: + list[packaging.version.Version]: A list of valid versions found. + """ api = HfApi() repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] @@ -368,9 +510,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: - """ - Returns the version if available on repo or the latest compatible one. - Otherwise, will throw a `CompatibilityError`. + """Return the specified version if available on repo, or the latest compatible one. + + If the exact version is not found, it looks for the latest version with the + same major version number that is less than or equal to the target minor version. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + version (str | packaging.version.Version): The target version. + + Returns: + str: The safe version string (e.g., "v1.2.3") to use as a revision. + + Raises: + RevisionNotFoundError: If the repo has no version tags. + BackwardCompatibilityError: If only older major versions are available. + ForwardCompatibilityError: If only newer major versions are available. """ target_version = ( packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version @@ -412,6 +567,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. + + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. + """ hf_features = {} for key, ft in features.items(): if ft["dtype"] == "video": @@ -439,6 +605,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ invalid_features = {name: ft for name, ft in features.items() if "/" in name} if invalid_features: raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") @@ -447,8 +621,28 @@ def _validate_feature_names(features: dict[str, dict]) -> None: def hw_to_dataset_features( hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True ) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ features = {} - joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} if joint_fts and prefix == "action": @@ -479,6 +673,20 @@ def hw_to_dataset_features( def build_dataset_frame( ds_features: dict[str, dict], values: dict[str, Any], prefix: str ) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ frame = {} for key, ft in ds_features.items(): if key in DEFAULT_FEATURES or not key.startswith(prefix): @@ -492,6 +700,21 @@ def build_dataset_frame( def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ # TODO(aliberts): Implement "type" in dataset features and simplify this policy_features = {} for key, ft in features.items(): @@ -522,6 +745,58 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea return policy_features +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. + """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out + + def create_empty_dataset_info( codebase_version: str, fps: int, @@ -532,6 +807,18 @@ def create_empty_dataset_info( data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, ) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ return { "codebase_version": codebase_version, "robot_type": robot_type, @@ -552,9 +839,23 @@ def create_empty_dataset_info( def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: - """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. - This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be - actual timestamps from the dataset. + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): @@ -580,6 +881,15 @@ def check_delta_timestamps( def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] @@ -588,9 +898,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic def cycle(iterable: Any) -> Iterator[Any]: - """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + """Create a dataloader-safe cyclical iterator. - See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. + This is an equivalent of `itertools.cycle` but is safe for use with + PyTorch DataLoaders with multiple workers. + See https://github.com/pytorch/pytorch/issues/23900 for details. + + Args: + iterable: The iterable to cycle over. + + Yields: + Items from the iterable, restarting from the beginning when exhausted. """ iterator = iter(iterable) while True: @@ -601,8 +919,14 @@ def cycle(iterable: Any) -> Iterator[Any]: def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None: - """Create a branch on a existing Hugging Face repo. Delete the branch if it already - exists before creating it. + """Create a branch on an existing Hugging Face repo. + + Deletes the branch if it already exists before creating it. + + Args: + repo_id (str): The ID of the repository. + branch (str): The name of the branch to create. + repo_type (str | None): The type of the repository (e.g., "dataset"). """ api = HfApi() @@ -620,9 +944,20 @@ def create_lerobot_dataset_card( dataset_info: dict | None = None, **kwargs, ) -> DatasetCard: - """ - Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md. - Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. + """Create a `DatasetCard` for a LeRobot dataset. + + Keyword arguments are used to replace values in the card template. + Note: If specified, `license` must be a valid license identifier from + https://huggingface.co/docs/hub/repositories-licenses. + + Args: + tags (list | None): A list of tags to add to the dataset card. + dataset_info (dict | None): The dataset's info dictionary, which will + be displayed on the card. + **kwargs: Additional keyword arguments to populate the card template. + + Returns: + DatasetCard: The generated dataset card object. """ card_tags = ["LeRobot"] @@ -675,6 +1010,15 @@ def validate_frame(frame: dict, features: dict) -> None: def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - expected_features @@ -692,6 +1036,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set def validate_feature_dtype_and_shape( name: str, feature: dict, value: np.ndarray | PILImage.Image | str ) -> str: + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. + """ expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -707,6 +1064,17 @@ def validate_feature_dtype_and_shape( def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray ) -> str: + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -726,6 +1094,18 @@ def validate_feature_numpy_array( def validate_feature_image_or_video( name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image ) -> str: + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -742,12 +1122,35 @@ def validate_feature_image_or_video( def validate_feature_string(name: str, value: str) -> str: + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: + """Validate the episode buffer before it's written to disk. + + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 35797c6e..f71aca70 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -161,35 +161,73 @@ class XarmEnv(EnvConfig): @dataclass -class VideoRecordConfig: - """Configuration for video recording in ManiSkill environments.""" - - enabled: bool = False - record_dir: str = "videos" - trajectory_name: str = "trajectory" +class ImagePreprocessingConfig: + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None @dataclass -class EnvTransformConfig: - """Configuration for environment wrappers.""" +class RewardClassifierConfig: + """Configuration for reward classification.""" + + pretrained_path: str | None = None + success_threshold: float = 0.5 + success_reward: float = 1.0 + + +@dataclass +class InverseKinematicsConfig: + """Configuration for inverse kinematics processing.""" + + urdf_path: str | None = None + target_frame_name: str | None = None + end_effector_bounds: dict[str, list[float]] | None = None + end_effector_step_sizes: dict[str, float] | None = None + + +@dataclass +class ObservationConfig: + """Configuration for observation processing.""" - # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) - control_mode: str = "gamepad" - display_cameras: bool = False add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False add_ee_pose_to_observation: bool = False - crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None - resize_size: tuple[int, int] | None = None - control_time_s: float = 20.0 - fixed_reset_joint_positions: Any | None = None - reset_time_s: float = 5.0 + display_cameras: bool = False + + +@dataclass +class GripperConfig: + """Configuration for gripper control and penalties.""" + use_gripper: bool = True - gripper_quantization_threshold: float | None = 0.8 gripper_penalty: float = 0.0 gripper_penalty_in_reward: bool = False +@dataclass +class ResetConfig: + """Configuration for environment reset behavior.""" + + fixed_reset_joint_positions: Any | None = None + reset_time_s: float = 5.0 + control_time_s: float = 20.0 + terminate_on_success: bool = True + + +@dataclass +class HILSerlProcessorConfig: + """Configuration for environment processing pipeline.""" + + control_mode: str = "gamepad" + observation: ObservationConfig | None = None + image_preprocessing: ImagePreprocessingConfig | None = None + gripper: GripperConfig | None = None + reset: ResetConfig | None = None + inverse_kinematics: InverseKinematicsConfig | None = None + reward_classifier: RewardClassifierConfig | None = None + max_gripper_pos: float | None = 100.0 + + @EnvConfig.register_subclass(name="gym_manipulator") @dataclass class HILSerlRobotEnvConfig(EnvConfig): @@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig): robot: RobotConfig | None = None teleop: TeleoperatorConfig | None = None - wrapper: EnvTransformConfig | None = None - fps: int = 10 + processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig) + name: str = "real_robot" - mode: str | None = None # Either "record", "replay", None - repo_id: str | None = None - dataset_root: str | None = None - task: str | None = "" - num_episodes: int = 10 # only for record mode - episode: int = 0 - device: str = "cuda" - push_to_hub: bool = True - pretrained_policy_name_or_path: str | None = None - reward_classifier_pretrained_path: str | None = None - # For the reward classifier, to record more positive examples after a success - number_of_steps_after_success: int = 0 @property def gym_kwargs(self) -> dict: return {} - - -@EnvConfig.register_subclass("hil") -@dataclass -class HILEnvConfig(EnvConfig): - """Configuration for the HIL environment.""" - - name: str = "PandaPickCube" - task: str | None = "PandaPickCubeKeyboard-v0" - use_viewer: bool = True - gripper_penalty: float = 0.0 - use_gamepad: bool = True - state_dim: int = 18 - action_dim: int = 4 - fps: int = 100 - episode_length: int = 100 - video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) - features: dict[str, PolicyFeature] = field( - default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)), - } - ) - features_map: dict[str, str] = field( - default_factory=lambda: { - "action": ACTION, - "observation.image": OBS_IMAGE, - "observation.state": OBS_STATE, - } - ) - ################# args from hilserlrobotenv - reward_classifier_pretrained_path: str | None = None - robot_config: RobotConfig | None = None - teleop_config: TeleoperatorConfig | None = None - wrapper: EnvTransformConfig | None = None - mode: str | None = None # Either "record", "replay", None - repo_id: str | None = None - dataset_root: str | None = None - num_episodes: int = 10 # only for record mode - episode: int = 0 - device: str = "cuda" - push_to_hub: bool = True - pretrained_policy_name_or_path: str | None = None - # For the reward classifier, to record more positive examples after a success - number_of_steps_after_success: int = 0 - ############################ - - @property - def gym_kwargs(self) -> dict: - return { - "use_viewer": self.use_viewer, - "use_gamepad": self.use_gamepad, - "gripper_penalty": self.gripper_penalty, - } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index dc6d96d6..af8f5eaf 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) - elif env_type == "hil": - return HILEnvConfig(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 00676a01..b4f65ee9 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -127,9 +127,29 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: """Adds task feature to the observation dict with respect to the first environment attribute.""" if hasattr(env.envs[0], "task_description"): - observation["task"] = env.call("task_description") + task_result = env.call("task_description") + + if isinstance(task_result, tuple): + task_result = list(task_result) + + if not isinstance(task_result, list): + raise TypeError(f"Expected task_description to return a list, got {type(task_result)}") + if not all(isinstance(item, str) for item in task_result): + raise TypeError("All items in task_description result must be strings") + + observation["task"] = task_result elif hasattr(env.envs[0], "task"): - observation["task"] = env.call("task") + task_result = env.call("task") + + if isinstance(task_result, tuple): + task_result = list(task_result) + + if not isinstance(task_result, list): + raise TypeError(f"Expected task to return a list, got {type(task_result)}") + if not all(isinstance(item, str) for item in task_result): + raise TypeError("All items in task result must be strings") + + observation["task"] = task_result else: # For envs without language instructions, e.g. aloha transfer cube and etc. num_envs = observation[list(observation.keys())[0]].shape[0] observation["task"] = ["" for _ in range(num_envs)] diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 9cb0f623..9b9de993 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,17 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi0.processor_pi0 import Pi0NewLineProcessor from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig +from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig + +__all__ = [ + "ACTConfig", + "DiffusionConfig", + "PI0Config", + "SmolVLAConfig", + "TDMPCConfig", + "VQBeTConfig", +] diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index cfd549b2..e0f3462c 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.constants import ACTION, OBS_IMAGES from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy @@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy): def __init__( self, config: ACTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.model = ACT(config) if config.temporal_ensemble_coeff is not None: @@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy): """Predict a chunk of actions given environment observations.""" self.eval() - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] actions = self.model(batch)[0] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] - batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py new file mode 100644 index 00000000..b0d2067e --- /dev/null +++ b/src/lerobot/policies/act/processor_act.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_act_pre_post_processors( + config: ACTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Creates the pre- and post-processing pipelines for the ACT policy. + + The pre-processing pipeline handles normalization, batching, and device placement for the model inputs. + The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU. + + Args: + config (ACTConfig): The ACT policy configuration object. + dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset + statistics (e.g., mean and std) used for normalization. Defaults to None. + + Returns: + tuple[PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction]]: A tuple containing the + pre-processor pipeline and the post-processor pipeline. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + device=config.device, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 85d4d598..747ead33 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -35,7 +35,6 @@ from torch import Tensor, nn from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import ( get_device_from_parameters, @@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy): def __init__( self, config: DiffusionConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: @@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy): batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) - # TODO(rcadene): make above methods return output dictionary? - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] - return actions @torch.no_grad() @@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy): if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) @@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None return loss, None diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py new file mode 100644 index 00000000..4383ec95 --- /dev/null +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_diffusion_pre_post_processors( + config: DiffusionConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for a diffusion policy. + + The pre-processing pipeline prepares the input data for the model by: + 1. Renaming features. + 2. Normalizing the input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving the data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving the data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the diffusion policy, + containing feature definitions, normalization mappings, and device information. + dataset_stats: A dictionary of statistics used for normalization. + Defaults to None. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ef56bdb6..06c0c4ba 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -14,12 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +from __future__ import annotations -from torch import nn +import logging +from typing import Any, TypedDict + +import torch +from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig @@ -34,10 +39,32 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor.converters import ( + batch_to_transition, + policy_action_to_transition, + transition_to_batch, + transition_to_policy_action, +) -def get_policy_class(name: str) -> PreTrainedPolicy: - """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" +def get_policy_class(name: str) -> type[PreTrainedPolicy]: + """ + Retrieves a policy class by its registered name. + + This function uses dynamic imports to avoid loading all policy classes into memory + at once, improving startup time and reducing dependencies. + + Args: + name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", + "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + + Returns: + The policy class corresponding to the given name. + + Raises: + NotImplementedError: If the policy name is not recognized. + """ if name == "tdmpc": from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy @@ -79,6 +106,24 @@ def get_policy_class(name: str) -> PreTrainedPolicy: def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + """ + Instantiates a policy configuration object based on the policy type. + + This factory function simplifies the creation of policy configuration objects by + mapping a string identifier to the corresponding config class. + + Args: + policy_type: The type of the policy. Supported types include "tdmpc", + "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", + "reward_classifier". + **kwargs: Keyword arguments to be passed to the configuration class constructor. + + Returns: + An instance of a `PreTrainedConfig` subclass. + + Raises: + ValueError: If the `policy_type` is not recognized. + """ if policy_type == "tdmpc": return TDMPCConfig(**kwargs) elif policy_type == "diffusion": @@ -101,30 +146,187 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: raise ValueError(f"Policy type '{policy_type}' is not available.") +class ProcessorConfigKwargs(TypedDict, total=False): + """ + A TypedDict defining the keyword arguments for processor configuration. + + This provides type hints for the optional arguments passed to `make_pre_post_processors`, + improving code clarity and enabling static analysis. + + Attributes: + preprocessor_config_filename: The filename for the preprocessor configuration. + postprocessor_config_filename: The filename for the postprocessor configuration. + preprocessor_overrides: A dictionary of overrides for the preprocessor configuration. + postprocessor_overrides: A dictionary of overrides for the postprocessor configuration. + dataset_stats: Dataset statistics for normalization. + """ + + preprocessor_config_filename: str | None + postprocessor_config_filename: str | None + preprocessor_overrides: dict[str, Any] | None + postprocessor_overrides: dict[str, Any] | None + dataset_stats: dict[str, dict[str, torch.Tensor]] | None + + +def make_pre_post_processors( + policy_cfg: PreTrainedConfig, + pretrained_path: str | None = None, + **kwargs: Unpack[ProcessorConfigKwargs], +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Create or load pre- and post-processor pipelines for a given policy. + + This function acts as a factory. It can either load existing processor pipelines + from a pretrained path or create new ones from scratch based on the policy + configuration. Each policy type has a dedicated factory function for its + processors (e.g., `make_tdmpc_pre_post_processors`). + + Args: + policy_cfg: The configuration of the policy for which to create processors. + pretrained_path: An optional path to load pretrained processor pipelines from. + If provided, pipelines are loaded from this path. + **kwargs: Keyword arguments for processor configuration, as defined in + `ProcessorConfigKwargs`. + + Returns: + A tuple containing the input (pre-processor) and output (post-processor) pipelines. + + Raises: + NotImplementedError: If a processor factory is not implemented for the given + policy configuration type. + """ + if pretrained_path: + return ( + PolicyProcessorPipeline.from_pretrained( + pretrained_model_name_or_path=pretrained_path, + config_filename=kwargs.get( + "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" + ), + overrides=kwargs.get("preprocessor_overrides", {}), + to_transition=batch_to_transition, + to_output=transition_to_batch, + ), + PolicyProcessorPipeline.from_pretrained( + pretrained_model_name_or_path=pretrained_path, + config_filename=kwargs.get( + "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" + ), + overrides=kwargs.get("postprocessor_overrides", {}), + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + # Create a new processor based on policy type + if isinstance(policy_cfg, TDMPCConfig): + from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors + + processors = make_tdmpc_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, DiffusionConfig): + from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors + + processors = make_diffusion_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, ACTConfig): + from lerobot.policies.act.processor_act import make_act_pre_post_processors + + processors = make_act_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, VQBeTConfig): + from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors + + processors = make_vqbet_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, PI0Config): + from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors + + processors = make_pi0_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, PI0FASTConfig): + from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors + + processors = make_pi0fast_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, SACConfig): + from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors + + processors = make_sac_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, RewardClassifierConfig): + from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor + + processors = make_classifier_processor( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, SmolVLAConfig): + from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors + + processors = make_smolvla_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + else: + raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") + + return processors + + def make_policy( cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata | None = None, env_cfg: EnvConfig | None = None, ) -> PreTrainedPolicy: - """Make an instance of a policy class. + """ + Instantiate a policy model. - This function exists because (for now) we need to parse features from either a dataset or an environment - in order to properly dimension and instantiate a policy for that dataset or environment. + This factory function handles the logic of creating a policy, which requires + determining the input and output feature shapes. These shapes can be derived + either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function + can either initialize a new policy from scratch or load a pretrained one. Args: - cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will - be loaded with the weights from that path. - ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and - statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. - env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be - provided if ds_meta is not. Defaults to None. - - Raises: - ValueError: Either ds_meta or env and env_cfg must be provided. - NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility) + cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is + set, the policy will be loaded with weights from that path. + ds_meta: Dataset metadata used to infer feature shapes and types. Also provides + statistics for normalization layers. + env_cfg: Environment configuration used to infer feature shapes and types. + One of `ds_meta` or `env_cfg` must be provided. Returns: - PreTrainedPolicy: _description_ + An instantiated and device-placed policy model. + + Raises: + ValueError: If both or neither of `ds_meta` and `env_cfg` are provided. + NotImplementedError: If attempting to use an unsupported policy-backend + combination (e.g., VQBeT with 'mps'). """ if bool(ds_meta) == bool(env_cfg): raise ValueError("Either one of a dataset metadata or a sim env must be provided.") @@ -147,7 +349,6 @@ def make_policy( kwargs = {} if ds_meta is not None: features = dataset_to_policy_features(ds_meta.features) - kwargs["dataset_stats"] = ds_meta.stats else: if not cfg.pretrained_path: logging.warning( @@ -155,6 +356,8 @@ def make_policy( "rather than a dataset. Normalization modules inside the policy will have infinite values " "by default without stats from a dataset." ) + if env_cfg is None: + raise ValueError("env_cfg cannot be None when ds_meta is not provided") features = env_to_policy_features(env_cfg) cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} @@ -171,7 +374,7 @@ def make_policy( policy = policy_cls(**kwargs) policy.to(cfg.device) - assert isinstance(policy, nn.Module) + assert isinstance(policy, torch.nn.Module) # policy = torch.compile(policy, mode="reduce-overhead") diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py deleted file mode 100644 index 11905587..00000000 --- a/src/lerobot/policies/normalize.py +++ /dev/null @@ -1,420 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import torch -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - - -def create_stats_buffers( - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> dict[str, dict[str, nn.ParameterDict]]: - """ - Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max - statistics. - - Args: (see Normalize and Unnormalize) - - Returns: - dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing - `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. - """ - stats_buffers = {} - - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - assert isinstance(norm_mode, NormalizationMode) - - shape = tuple(ft.shape) - - if ft.type is FeatureType.VISUAL: - # sanity checks - assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" - c, h, w = shape - assert c < h and c < w, f"{key} is not channel first ({shape=})" - # override image shape to be invariant to height and width - shape = (c, 1, 1) - - # Note: we initialize mean, std, min, max to infinity. They should be overwritten - # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, - # we assert they are not infinity anymore. - - buffer = {} - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.ones(shape, dtype=torch.float32) * torch.inf - std = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "mean": nn.Parameter(mean, requires_grad=False), - "std": nn.Parameter(std, requires_grad=False), - } - ) - elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.ones(shape, dtype=torch.float32) * torch.inf - max = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "min": nn.Parameter(min, requires_grad=False), - "max": nn.Parameter(max, requires_grad=False), - } - ) - - # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) - if stats: - if isinstance(stats[key]["mean"], np.ndarray): - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) - elif isinstance(stats[key]["mean"], torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) - buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["mean"]) - raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") - - stats_buffers[key] = buffer - return stats_buffers - - -def _no_stats_error_str(name: str) -> str: - return ( - f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " - "pretrained model." - ) - - -class Normalize(nn.Module): - """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # TODO: Remove this shallow copy - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - # FIXME(aliberts, rcadene): This might lead to silent fail! - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min + 1e-8) - # normalize to [-1, 1] - batch[key] = batch[key] * 2 - 1 - else: - raise ValueError(norm_mode) - return batch - - -class Unnormalize(nn.Module): - """ - Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their - original range used by the environment. - """ - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max - min) + min - else: - raise ValueError(norm_mode) - return batch - - -# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization -# and remove the `Normalize` and `Unnormalize` classes. -def _initialize_stats_buffers( - module: nn.Module, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> None: - """Register statistics buffers (mean/std or min/max) on the given *module*. - - The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, - but is factored out so it can be reused by both classes and stay in sync. - """ - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - shape: tuple[int, ...] = tuple(ft.shape) - if ft.type is FeatureType.VISUAL: - # reduce spatial dimensions, keep channel dimension only - c, *_ = shape - shape = (c, 1, 1) - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.full(shape, torch.inf, dtype=torch.float32) - std = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: - mean_data = stats[key]["mean"] - std_data = stats[key]["std"] - if isinstance(mean_data, torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - mean = mean_data.clone().to(dtype=torch.float32) - std = std_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_mean", mean) - module.register_buffer(f"{prefix}_std", std) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = torch.full(shape, torch.inf, dtype=torch.float32) - max_val = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "min" in stats[key] and "max" in stats[key]: - min_data = stats[key]["min"] - max_data = stats[key]["max"] - if isinstance(min_data, torch.Tensor): - min_val = min_data.clone().to(dtype=torch.float32) - max_val = max_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_min", min_val) - module.register_buffer(f"{prefix}_max", max_val) - continue - - raise ValueError(norm_mode) - - -class NormalizeBuffer(nn.Module): - """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) - batch[key] = batch[key] * 2 - 1 - continue - - raise ValueError(norm_mode) - - return batch - - -class UnnormalizeBuffer(nn.Module): - """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max_val - min_val) + min_val - continue - - raise ValueError(norm_mode) - - return batch diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index de41e2bd..66bd81e6 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -56,18 +56,15 @@ from collections import deque import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertConfig, PaliGemmaWithExpertModel, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import log_model_loading_keys -from lerobot.utils.utils import get_safe_dtype, init_logging +from lerobot.utils.utils import get_safe_dtype def create_sinusoidal_pos_embedding( @@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy): def __init__( self, config: PI0Config, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FlowMatching(config) self.reset() @@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) - @classmethod - def _transform_state_dict_keys(cls, state_dict: dict) -> dict: - """ - Transform state dict keys to match expected model structure. - - Transformations: - - model.paligemma_with_expert.paligemma.language_model.lm_head -> - model.paligemma_with_expert.paligemma.lm_head - - model.paligemma_with_expert.paligemma.language_model.model -> - model.paligemma_with_expert.paligemma.model.language_model - - model.paligemma_with_expert.paligemma.vision_tower -> - model.paligemma_with_expert.paligemma.model.vision_tower - - model.paligemma_with_expert.paligemma.multi_modal_projector -> - model.paligemma_with_expert.paligemma.model.multi_modal_projector - - Also handles tied weights between lm_head.weight and - embed_tokens.weight. - """ - import re - - transformed_dict = {} - - transformations = [ - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"), - ".paligemma_with_expert.paligemma.lm_head", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"), - ".paligemma_with_expert.paligemma.model.language_model", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"), - ".paligemma_with_expert.paligemma.model.vision_tower", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"), - ".paligemma_with_expert.paligemma.model.multi_modal_projector", - ), - ] - - for key, value in state_dict.items(): - new_key = key - for pattern, replacement in transformations: - new_key = pattern.sub(replacement, new_key) - transformed_dict[new_key] = value - - # Handle tied weights: lm_head.weight and embed_tokens.weight share memory - lm_head_key = None - embed_tokens_key = None - - for key in transformed_dict: - if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"): - lm_head_key = key - elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"): - embed_tokens_key = key - if lm_head_key and embed_tokens_key: - break - - if lm_head_key and not embed_tokens_key: - embed_tokens_key = lm_head_key.replace( - ".lm_head.weight", ".model.language_model.embed_tokens.weight" - ) - transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key] - elif embed_tokens_key and not lm_head_key: - lm_head_key = embed_tokens_key.replace( - ".model.language_model.embed_tokens.weight", ".lm_head.weight" - ) - transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key] - - return transformed_dict - - @classmethod - def _load_as_safetensor( - cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool - ) -> "PI0Policy": - """Override to apply key transformations before loading.""" - from safetensors.torch import load_file - - init_logging() - # Load the state dict from file safely - state_dict = load_file(model_file, device=map_location) - - # Apply key transformations - transformed_state_dict = cls._transform_state_dict_keys(state_dict) - - # Load the transformed state dict - msg = model.load_state_dict(transformed_state_dict, strict=strict) - - # Log message - log_model_loading_keys(msg.missing_keys, msg.unexpected_keys) - return model - def get_optim_params(self) -> dict: return self.parameters() @@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.model.sample_actions( images, img_masks, lang_tokens, lang_masks, state, noise=noise @@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy): original_action_dim = self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({"action": actions})["action"] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy): batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) actions_is_pad = batch.get("action_is_pad") @@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy): return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - - # PaliGemma prompt has to end with a new line - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding="max_length", - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: @@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: PI0Config): super().__init__() self.config = config diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py new file mode 100644 index 00000000..cd971220 --- /dev/null +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +@ProcessorStepRegistry.register(name="pi0_new_line_processor") +class Pi0NewLineProcessor(ComplementaryDataProcessorStep): + """ + Ensures that the task description string ends with a newline character. + + This processing step is required for compatibility with the PaliGemma tokenizer, + which expects a newline at the end of the text prompt. It handles both single + strings and lists of strings for the 'task' key in complementary data. + """ + + def complementary_data(self, complementary_data): + """ + Adds a newline to the 'task' field if it doesn't already have one. + + Args: + complementary_data: A dictionary that may contain a 'task' key with a + string or list of strings. + + Returns: + A new dictionary with the modified 'task' field. + """ + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + + Args: + features: The input feature dictionary. + + Returns: + The unchanged feature dictionary. + """ + return features + + +def make_pi0_pre_post_processors( + config: PI0Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 88727b58..682a372f 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache from transformers.models.auto import CONFIG_MAPPING from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy @@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FAST(config) @@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: @@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy): ] # self.config.max_action_dim # self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({"action": actions})["action"] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) loss_dict = self.model.forward(batch) return loss_dict["loss"], loss_dict diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py new file mode 100644 index 00000000..81314aa3 --- /dev/null +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_pi0fast_pre_post_processors( + config: PI0FASTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0Fast policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0Fast policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index 878f3cdd..fcaf02a4 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution -from lerobot.policies.normalize import NormalizeBuffer from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters @@ -45,7 +44,6 @@ class SACPolicy( def __init__( self, config: SACConfig | None = None, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): super().__init__(config) config.validate_features() @@ -53,7 +51,6 @@ class SACPolicy( # Determine action dimension and initialize all components continuous_action_dim = config.output_features["action"].shape[0] - self._init_normalization(dataset_stats) self._init_encoders() self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) @@ -88,8 +85,7 @@ class SACPolicy( observations_features = None if self.shared_encoder and self.actor.encoder.has_images: - # Cache and normalize image features - observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True) + observations_features = self.actor.encoder.get_cached_image_features(batch) actions, _, _ = self.actor(batch, observations_features) @@ -391,28 +387,12 @@ class SACPolicy( actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() return actor_loss - def _init_normalization(self, dataset_stats): - """Initialize input/output normalization modules.""" - self.normalize_inputs = nn.Identity() - self.normalize_targets = nn.Identity() - if self.config.dataset_stats is not None: - params = _convert_normalization_params_to_tensor(self.config.dataset_stats) - self.normalize_inputs = NormalizeBuffer( - self.config.input_features, self.config.normalization_mapping, params - ) - stats = dataset_stats or params - self.normalize_targets = NormalizeBuffer( - self.config.output_features, self.config.normalization_mapping, stats - ) - def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" self.shared_encoder = self.config.shared_encoder - self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_critic = SACObservationEncoder(self.config) self.encoder_actor = ( - self.encoder_critic - if self.shared_encoder - else SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config) ) def _init_critics(self, continuous_action_dim): @@ -424,9 +404,7 @@ class SACPolicy( ) for _ in range(self.config.num_critics) ] - self.critic_ensemble = CriticEnsemble( - encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets - ) + self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) target_heads = [ CriticHead( input_dim=self.encoder_critic.output_dim + continuous_action_dim, @@ -434,9 +412,7 @@ class SACPolicy( ) for _ in range(self.config.num_critics) ] - self.critic_target = CriticEnsemble( - encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets - ) + self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) if self.config.use_torch_compile: @@ -490,10 +466,9 @@ class SACPolicy( class SACObservationEncoder(nn.Module): """Encode image and/or state vector observations.""" - def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: + def __init__(self, config: SACConfig) -> None: super().__init__() self.config = config - self.input_normalization = input_normalizer self._init_image_layers() self._init_state_layers() self._compute_output_dim() @@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module): def forward( self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False ) -> Tensor: - obs = self.input_normalization(obs) parts = [] if self.has_images: if cache is None: - cache = self.get_cached_image_features(obs, normalize=False) + cache = self.get_cached_image_features(obs) parts.append(self._encode_images(cache, detach)) if self.has_env: parts.append(self.env_encoder(obs["observation.environment_state"])) @@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module): "No parts to concatenate, you should have at least one image or environment state or state" ) - def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]: + def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]: """Extract and optionally cache image features from observations. This function processes image observations through the vision encoder once and returns @@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module): - The vision encoder forward pass is typically the main computational bottleneck during training and inference - Caching these features can provide 2-4x speedup in training and inference - Normalization behavior: - - When called from inside forward(): set normalize=False since inputs are already normalized - - When called from outside forward(): set normalize=True to ensure proper input normalization - Usage patterns: - - Called in select_action() with normalize=True + - Called in select_action() - Called in learner.py's get_observation_features() to pre-compute features for all policy components - - Called internally by forward() with normalize=False + - Called internally by forward() Args: obs: Dictionary of observation tensors containing image keys - normalize: Whether to normalize observations before encoding - Set to True when calling directly from outside the encoder's forward method - Set to False when calling from within forward() where inputs are already normalized Returns: Dictionary mapping image keys to their corresponding encoded features """ - if normalize: - obs = self.input_normalization(obs) batched = torch.cat([obs[k] for k in self.image_keys], dim=0) out = self.image_encoder(batched) chunks = torch.chunk(out, len(self.image_keys), dim=0) @@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module): Args: encoder (SACObservationEncoder): encoder for observations. ensemble (List[CriticHead]): list of critic heads. - output_normalization (nn.Module): normalization layer for actions. init_final (float | None): optional initializer scale for final layers. Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. @@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module): self, encoder: SACObservationEncoder, ensemble: list[CriticHead], - output_normalization: nn.Module, init_final: float | None = None, ): super().__init__() self.encoder = encoder self.init_final = init_final - self.output_normalization = output_normalization self.critics = nn.ModuleList(ensemble) def forward( @@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module): device = get_device_from_parameters(self) # Move each tensor in observations to device observations = {k: v.to(device) for k, v in observations.items()} - # NOTE: We normalize actions it helps for sample efficiency - actions: dict[str, torch.tensor] = {"action": actions} - # NOTE: Normalization layer took dict in input and outputs a dict that why - actions = self.output_normalization(actions)["action"] - actions = actions.to(device) obs_enc = self.encoder(observations, cache=observation_features) diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py new file mode 100644 index 00000000..9e8013d3 --- /dev/null +++ b/src/lerobot/policies/sac/processor_sac.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_sac_pre_post_processors( + config: SACConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the SAC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the SAC policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py index cadd1c9f..ca501c3a 100644 --- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py +++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py @@ -20,7 +20,6 @@ import torch from torch import Tensor, nn from lerobot.constants import OBS_IMAGE, REWARD -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy): def __init__( self, config: RewardClassifierConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): from transformers import AutoModel super().__init__(config) self.config = config - # Initialize normalization (standardized with the policy framework) - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - # Set up encoder encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) # Extract vision model if we're given a multimodal model @@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: """Standard forward pass for training compatible with train.py.""" - # Normalize inputs if needed - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - # Extract images and labels images, labels = self.extract_images_and_labels(batch) diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py new file mode 100644 index 00000000..c2a34eab --- /dev/null +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -0,0 +1,82 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.processor import ( + DeviceProcessorStep, + IdentityProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_classifier_processor( + config: RewardClassifierConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the reward classifier. + + The pre-processing pipeline prepares input data for the classifier by: + 1. Normalizing both input and output features based on dataset statistics. + 2. Moving the data to the specified device. + + The post-processing pipeline handles the classifier's output by: + 1. Moving the data to the CPU. + 2. Applying an identity step, as no unnormalization is needed for the output logits. + + Args: + config: The configuration object for the RewardClassifier. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + NormalizerProcessorStep( + features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + NormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ] + output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()] + + return ( + PolicyProcessorPipeline( + steps=input_steps, + name="classifier_preprocessor", + ), + PolicyProcessorPipeline( + steps=output_steps, + name="classifier_postprocessor", + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 18f2fc58..48d4b231 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") """ import math -import os -import re from collections import deque -import safetensors import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoProcessor -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import ( - Normalize, - Unnormalize, -) +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel @@ -76,102 +68,6 @@ from lerobot.policies.utils import ( ) from lerobot.utils.utils import get_safe_dtype -# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker -_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -def canonicalise(k: str) -> str: - """ - Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a - normalisation-buffer key. - """ - return _VARIANT_RE.sub(".buffer_", k) - - -def standardise_state_dict( - checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -) -> tuple[dict[str, torch.Tensor], list[str]]: - """ - • Re-keys `checkpoint ` so that every entry matches the *reference* key set. - • If several variant keys collapse to the same canonical name we keep the - first one and log the collision. - • Returns the new dict + a list of entries that could not be matched. - """ - out, collisions, unmatched = {}, {}, [] - - for k, v in checkpoint.items(): - canon = canonicalise(k) - if canon in ref_keys: - if canon in out: # duplicate after collapsing - collisions.setdefault(canon, []).append(k) - else: - out[canon] = v - else: - unmatched.append(k) - - if verbose: - for canon, variants in collisions.items(): - print(f"[standardise_state_dict] '{canon}' ← {variants}") - if unmatched: - print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - - out.update({k: checkpoint[k] for k in unmatched}) - return out, unmatched - - -def rename_checkpoint_keys(checkpoint: dict, rename_str: str): - """ - Renames keys in a checkpoint dictionary based on the given rename string. - - Args: - checkpoint (dict): The checkpoint dictionary. - rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - - Returns: - dict: The modified checkpoint with renamed keys. - """ - - rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - - new_checkpoint = {} - for k, v in checkpoint.items(): - for old_key, new_key in rename_dict.items(): - if old_key in k: - k = k.replace(old_key, new_key) - new_checkpoint[k] = v - return new_checkpoint - - -def load_smolvla( - model: torch.nn.Module, - filename: str | os.PathLike, - *, - device: str = "cpu", - checkpoint_keys_mapping: str = "", -) -> torch.nn.Module: - state_dict = safetensors.torch.load_file(filename, device=device) - - # Optional user-supplied renames (e.g. "model._orig_mod.//model.") - if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: - state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - - state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - - # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset - norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") - state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - - missing, unexpected = model.load_state_dict(state_dict, strict=False) - - if not all(key.startswith(norm_keys) for key in missing) or unexpected: - raise RuntimeError( - "SmolVLA %d missing / %d unexpected keys", - len(missing), - len(unexpected), - ) - - return model - def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy): def __init__( self, config: SmolVLAConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer self.model = VLAFlowMatching(config) self.reset() @@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } - # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues - @classmethod - def _load_as_safetensor( - cls, - model: "SmolVLAPolicy", - model_file: str, - map_location: str, - strict: bool, - ): - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) - return load_smolvla( - model, - model_file, - device=map_location, - checkpoint_keys_mapping="model._orig_mod.//model.", - ) - def get_optim_params(self) -> dict: return self.parameters() @@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy): images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) @@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy): original_action_dim = self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - return batch @torch.no_grad() @@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) + images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) actions_is_pad = batch.get("actions_id_pad") loss_dict = {} @@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy): img_masks.append(mask) return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - - if len(tasks) == 1: - tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding=self.config.pad_language_to, - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py new file mode 100644 index 00000000..ac3cd462 --- /dev/null +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_smolvla_pre_post_processors( + config: SmolVLAConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the SmolVLA policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Ensuring the language task description ends with a newline character. + 5. Tokenizing the language task description. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output actions to their original scale. + + Args: + config: The configuration object for the SmolVLA policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + SmolVLANewLineProcessor(), + TokenizerProcessorStep( + tokenizer_name=config.vlm_model_name, + padding=config.pad_language_to, + padding_side="right", + max_length=config.tokenizer_max_length, + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +@ProcessorStepRegistry.register(name="smolvla_new_line_processor") +class SmolVLANewLineProcessor(ComplementaryDataProcessorStep): + """ + A processor step that ensures the 'task' description ends with a newline character. + + This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a + newline at the end of the prompt. It handles both single string tasks and lists + of string tasks. + """ + + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 7ba88e5e..e160310b 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues @@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy): config_class = TDMPCConfig name = "tdmpc" - def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + def __init__( + self, + config: TDMPCConfig, + ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.model = TDMPCTOLD(config) self.model_target = deepcopy(self.model) for param in self.model_target.parameters(): @@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy): actions = torch.clamp(actions, -1, +1) - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @torch.no_grad() @@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy): if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) - if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) self._queues = populate_queues(self._queues, batch) @@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy): """ device = get_device_from_parameters(self) - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] - batch = self.normalize_targets(batch) info = {} diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py new file mode 100644 index 00000000..75a7d4f7 --- /dev/null +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_tdmpc_pre_post_processors( + config: TDMPCConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the TDMPC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the TDMPC policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index feb65bb4..bb6040e9 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -28,7 +28,6 @@ import torchvision from torch import Tensor, nn from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy): def __init__( self, config: VQBeTConfig | None = None, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: @@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.vqbet = VQBeTModel(config) self.reset() @@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy): def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @torch.no_grad() @@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy): # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original # NOTE: It's important that this happens after stacking the images into a single key. batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) self._queues = populate_queues(self._queues, batch) @@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): # loss: total loss of training RVQ diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py new file mode 100644 index 00000000..1c741cd3 --- /dev/null +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru +# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_vqbet_pre_post_processors( + config: VQBeTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the VQ-BeT policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features, allowing customization to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the VQ-BeT policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 8dd244c2..be11ac1a 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,41 +14,120 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .device_processor import DeviceProcessor -from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor -from .observation_processor import VanillaObservationProcessor -from .pipeline import ( - ActionProcessor, - DoneProcessor, +from .batch_processor import AddBatchDimensionProcessorStep +from .converters import ( + batch_to_transition, + create_transition, + transition_to_batch, +) +from .core import ( + EnvAction, EnvTransition, - IdentityProcessor, - InfoProcessor, - ObservationProcessor, + PolicyAction, + RobotAction, + RobotObservation, + TransitionKey, +) +from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep +from .device_processor import DeviceProcessorStep +from .factory import ( + make_default_processors, + make_default_robot_action_processor, + make_default_robot_observation_processor, + make_default_teleop_action_processor, +) +from .gym_action_processor import ( + Numpy2TorchActionProcessorStep, + Torch2NumpyActionProcessorStep, +) +from .hil_processor import ( + AddTeleopActionAsComplimentaryDataStep, + AddTeleopEventsAsInfoStep, + GripperPenaltyProcessorStep, + ImageCropResizeProcessorStep, + InterventionActionProcessorStep, + RewardClassifierProcessorStep, + TimeLimitProcessorStep, +) +from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep +from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats +from .observation_processor import VanillaObservationProcessorStep +from .pipeline import ( + ActionProcessorStep, + ComplementaryDataProcessorStep, + DataProcessorPipeline, + DoneProcessorStep, + IdentityProcessorStep, + InfoProcessorStep, + ObservationProcessorStep, + PolicyActionProcessorStep, + PolicyProcessorPipeline, + ProcessorKwargs, ProcessorStep, ProcessorStepRegistry, - RewardProcessor, - RobotProcessor, - TransitionKey, - TruncatedProcessor, + RewardProcessorStep, + RobotActionProcessorStep, + RobotProcessorPipeline, + TruncatedProcessorStep, ) -from .rename_processor import RenameProcessor +from .policy_robot_bridge import ( + PolicyActionToRobotActionProcessorStep, + RobotActionToPolicyActionProcessorStep, +) +from .rename_processor import RenameObservationsProcessorStep +from .tokenizer_processor import TokenizerProcessorStep __all__ = [ - "ActionProcessor", - "DeviceProcessor", - "DoneProcessor", + "ActionProcessorStep", + "AddTeleopActionAsComplimentaryDataStep", + "AddTeleopEventsAsInfoStep", + "ComplementaryDataProcessorStep", + "batch_to_transition", + "create_transition", + "DeviceProcessorStep", + "DoneProcessorStep", + "EnvAction", "EnvTransition", - "IdentityProcessor", - "InfoProcessor", - "NormalizerProcessor", - "UnnormalizerProcessor", - "ObservationProcessor", + "GripperPenaltyProcessorStep", + "hotswap_stats", + "IdentityProcessorStep", + "ImageCropResizeProcessorStep", + "InfoProcessorStep", + "InterventionActionProcessorStep", + "JointVelocityProcessorStep", + "make_default_processors", + "make_default_teleop_action_processor", + "make_default_robot_action_processor", + "make_default_robot_observation_processor", + "MapDeltaActionToRobotActionStep", + "MapTensorToDeltaActionDictStep", + "MotorCurrentProcessorStep", + "NormalizerProcessorStep", + "Numpy2TorchActionProcessorStep", + "ObservationProcessorStep", + "PolicyAction", + "PolicyActionProcessorStep", + "PolicyProcessorPipeline", + "ProcessorKwargs", "ProcessorStep", "ProcessorStepRegistry", - "RenameProcessor", - "RewardProcessor", - "RobotProcessor", + "RobotAction", + "RobotActionProcessorStep", + "RobotObservation", + "RenameObservationsProcessorStep", + "RewardClassifierProcessorStep", + "RewardProcessorStep", + "DataProcessorPipeline", + "TimeLimitProcessorStep", + "AddBatchDimensionProcessorStep", + "RobotProcessorPipeline", + "TokenizerProcessorStep", + "Torch2NumpyActionProcessorStep", + "RobotActionToPolicyActionProcessorStep", + "PolicyActionToRobotActionProcessorStep", + "transition_to_batch", "TransitionKey", - "TruncatedProcessor", - "VanillaObservationProcessor", + "TruncatedProcessorStep", + "UnnormalizerProcessorStep", + "VanillaObservationProcessorStep", ] diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py new file mode 100644 index 00000000..a563599c --- /dev/null +++ b/src/lerobot/processor/batch_processor.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script defines processor steps for adding a batch dimension to various components of an environment transition. + +These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model. +""" + +from dataclasses import dataclass, field + +from torch import Tensor + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE + +from .core import EnvTransition, PolicyAction +from .pipeline import ( + ComplementaryDataProcessorStep, + ObservationProcessorStep, + PolicyActionProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_action") +class AddBatchDimensionActionStep(PolicyActionProcessorStep): + """ + Processor step to add a batch dimension to a 1D tensor action. + + This is useful for creating a batch of size 1 from a single action sample. + """ + + def action(self, action: PolicyAction) -> PolicyAction: + """ + Adds a batch dimension to the action if it's a 1D tensor. + + Args: + action: The action tensor. + + Returns: + The action tensor with an added batch dimension. + """ + if action.dim() != 1: + return action + return action.unsqueeze(0) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_observation") +class AddBatchDimensionObservationStep(ObservationProcessorStep): + """ + Processor step to add a batch dimension to observations. + + It handles different types of observations: + - State vectors (1D tensors). + - Single images (3D tensors). + - Dictionaries of multiple images (3D tensors). + """ + + def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: + """ + Adds a batch dimension to tensor-based observations in the observation dictionary. + + Args: + observation: The observation dictionary. + + Returns: + The observation dictionary with batch dimensions added to tensors. + """ + # Process state observations - add batch dim if 1D + for state_key in [OBS_STATE, OBS_ENV_STATE]: + if state_key in observation: + state_value = observation[state_key] + if isinstance(state_value, Tensor) and state_value.dim() == 1: + observation[state_key] = state_value.unsqueeze(0) + + # Process single image observation - add batch dim if 3D + if OBS_IMAGE in observation: + image_value = observation[OBS_IMAGE] + if isinstance(image_value, Tensor) and image_value.dim() == 3: + observation[OBS_IMAGE] = image_value.unsqueeze(0) + + # Process multiple image observations - add batch dim if 3D + for key, value in observation.items(): + if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: + observation[key] = value.unsqueeze(0) + return observation + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") +class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): + """ + Processor step to add a batch dimension to complementary data fields. + + Handles specific keys like 'task', 'index', and 'task_index' to make them batched. + - 'task' (str) is wrapped in a list. + - 'index' and 'task_index' (0D tensors) get a batch dimension. + """ + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Adds a batch dimension to specific fields in the complementary data dictionary. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The complementary data dictionary with batch dimensions added. + """ + # Process task field - wrap string in list to add batch dimension + if "task" in complementary_data: + task_value = complementary_data["task"] + if isinstance(task_value, str): + complementary_data["task"] = [task_value] + + # Process index field - add batch dim if 0D + if "index" in complementary_data: + index_value = complementary_data["index"] + if isinstance(index_value, Tensor) and index_value.dim() == 0: + complementary_data["index"] = index_value.unsqueeze(0) + + # Process task_index field - add batch dim if 0D + if "task_index" in complementary_data: + task_index_value = complementary_data["task_index"] + if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: + complementary_data["task_index"] = task_index_value.unsqueeze(0) + return complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor") +class AddBatchDimensionProcessorStep(ProcessorStep): + """ + A composite processor step that adds a batch dimension to the entire environment transition. + + This step combines individual processors for actions, observations, and complementary data + to create a batched transition (batch size 1) from a single-instance transition. + + Attributes: + to_batch_action_processor: Processor for the action component. + to_batch_observation_processor: Processor for the observation component. + to_batch_complementary_data_processor: Processor for the complementary data component. + """ + + to_batch_action_processor: AddBatchDimensionActionStep = field( + default_factory=AddBatchDimensionActionStep + ) + to_batch_observation_processor: AddBatchDimensionObservationStep = field( + default_factory=AddBatchDimensionObservationStep + ) + to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field( + default_factory=AddBatchDimensionComplementaryDataStep + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies the batching process to all relevant parts of an environment transition. + + Args: + transition: The environment transition to process. + + Returns: + The environment transition with a batch dimension added. + """ + if transition[TransitionKey.ACTION] is not None: + transition = self.to_batch_action_processor(transition) + if transition[TransitionKey.OBSERVATION] is not None: + transition = self.to_batch_observation_processor(transition) + if transition[TransitionKey.COMPLEMENTARY_DATA] is not None: + transition = self.to_batch_complementary_data_processor(transition) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + # NOTE: We ignore the batch dimension when transforming features + return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py new file mode 100644 index 00000000..440f8b1d --- /dev/null +++ b/src/lerobot/processor/converters.py @@ -0,0 +1,412 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from functools import singledispatch +from typing import Any + +import numpy as np +import torch + +from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey + + +@singledispatch +def to_tensor( + value: Any, + *, + dtype: torch.dtype | None = torch.float32, + device: torch.device | str | None = None, +) -> torch.Tensor: + """ + Convert various data types to PyTorch tensors with configurable options. + + This is a unified tensor conversion function using single dispatch to handle + different input types appropriately. + + Args: + value: Input value to convert (tensor, array, scalar, sequence, etc.). + dtype: Target tensor dtype. If None, preserves original dtype. + device: Target device for the tensor. + + Returns: + A PyTorch tensor. + + Raises: + TypeError: If the input type is not supported. + """ + raise TypeError(f"Unsupported type for tensor conversion: {type(value)}") + + +@to_tensor.register(torch.Tensor) +def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for existing PyTorch tensors.""" + if dtype is not None: + value = value.to(dtype=dtype) + if device is not None: + value = value.to(device=device) + return value + + +@to_tensor.register(np.ndarray) +def _( + value: np.ndarray, + *, + dtype=torch.float32, + device=None, + **kwargs, +) -> torch.Tensor: + """Handle conversion for numpy arrays.""" + # Check for numpy scalars (0-dimensional arrays) and treat them as scalars. + if value.ndim == 0: + # Numpy scalars should be converted to 0-dimensional tensors. + scalar_value = value.item() + return torch.tensor(scalar_value, dtype=dtype, device=device) + + # Create tensor from numpy array. + tensor = torch.from_numpy(value) + + # Apply dtype and device conversion if specified. + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if device is not None: + tensor = tensor.to(device=device) + + return tensor + + +@to_tensor.register(int) +@to_tensor.register(float) +@to_tensor.register(np.integer) +@to_tensor.register(np.floating) +def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for scalar values including numpy scalars.""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(list) +@to_tensor.register(tuple) +def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for sequences (lists, tuples).""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(dict) +def _(value: dict, *, device=None, **kwargs) -> dict: + """Handle conversion for dictionaries by recursively converting their values to tensors.""" + if not value: + return {} + + result = {} + for key, sub_value in value.items(): + if sub_value is None: + continue + + if isinstance(sub_value, dict): + # Recursively process nested dictionaries. + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + continue + + # Convert individual values to tensors. + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + return result + + +def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any: + """ + Convert a PyTorch tensor to a numpy array or scalar if applicable. + + If the input is not a tensor, it is returned unchanged. + + Args: + x: The input, which can be a tensor or any other type. + + Returns: + A numpy array, a scalar, or the original input. + """ + if isinstance(x, torch.Tensor): + return x.item() if x.numel() == 1 else x.detach().cpu().numpy() + return x + + +def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: + """ + Extract complementary data from a batch dictionary. + + This includes padding flags, task description, and indices. + + Args: + batch: The batch dictionary. + + Returns: + A dictionary with the extracted complementary data. + """ + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + index_key = {"index": batch["index"]} if "index" in batch else {} + task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + + return {**pad_keys, **task_key, **index_key, **task_index_key} + + +def create_transition( + observation: dict[str, Any] | None = None, + action: PolicyAction | RobotAction | None = None, + reward: float = 0.0, + done: bool = False, + truncated: bool = False, + info: dict[str, Any] | None = None, + complementary_data: dict[str, Any] | None = None, +) -> EnvTransition: + """ + Create an `EnvTransition` dictionary with sensible defaults. + + Args: + observation: Observation dictionary. + action: Action dictionary. + reward: Scalar reward value. + done: Episode termination flag. + truncated: Episode truncation flag. + info: Additional info dictionary. + complementary_data: Complementary data dictionary. + + Returns: + A complete `EnvTransition` dictionary. + """ + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +def robot_action_observation_to_transition( + action_observation: tuple[RobotAction, RobotObservation], +) -> EnvTransition: + """ + Convert a raw robot action and observation dictionary into a standardized `EnvTransition`. + + Args: + action: The raw action dictionary from a teleoperation device or controller. + observation: The raw observation dictionary from the environment. + + Returns: + An `EnvTransition` containing the formatted observation. + """ + if not isinstance(action_observation, tuple): + raise ValueError("action_observation should be a tuple type with an action and observation") + + action, observation = action_observation + + if action is not None and not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") + + if observation is not None and not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}") + + return create_transition(action=action, observation=observation) + + +def robot_action_to_transition(action: RobotAction) -> EnvTransition: + """ + Convert a raw robot action dictionary into a standardized `EnvTransition`. + + Args: + action: The raw action dictionary from a teleoperation device or controller. + + Returns: + An `EnvTransition` containing the formatted action. + """ + if not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") + return create_transition(action=action) + + +def observation_to_transition(observation: RobotObservation) -> EnvTransition: + """ + Convert a raw robot observation dictionary into a standardized `EnvTransition`. + + Args: + observation: The raw observation dictionary from the environment. + + Returns: + An `EnvTransition` containing the formatted observation. + """ + if not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}") + return create_transition(observation=observation) + + +def transition_to_robot_action(transition: EnvTransition) -> RobotAction: + """ + Extract a raw robot action dictionary for a robot from an `EnvTransition`. + + This function searches for keys in the format "action.*.pos" or "action.*.vel" + and converts them into a flat dictionary suitable for sending to a robot controller. + + Args: + transition: The `EnvTransition` containing the action. + + Returns: + A dictionary representing the raw robot action. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}") + return transition.get(TransitionKey.ACTION) + + +def transition_to_policy_action(transition: EnvTransition) -> PolicyAction: + """ + Convert an `EnvTransition` to a `PolicyAction`. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return action + + +def transition_to_observation(transition: EnvTransition) -> RobotObservation: + """ + Convert an `EnvTransition` to a `RobotObservation`. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + observation = transition.get(TransitionKey.OBSERVATION) + if not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}") + return observation + + +def policy_action_to_transition(action: PolicyAction) -> EnvTransition: + """ + Convert a `PolicyAction` to an `EnvTransition`. + """ + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return create_transition(action=action) + + +def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: + """ + Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`. + + This function maps recognized keys from a batch to the `EnvTransition` structure, + filling in missing keys with sensible defaults. + + Args: + batch: A batch dictionary. + + Returns: + An `EnvTransition` dictionary. + + Raises: + ValueError: If the input is not a dictionary. + """ + + # Validate input type. + if not isinstance(batch, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") + + action = batch.get("action") + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + # Extract observation and complementary data keys. + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + complementary_data = _extract_complementary_data(batch) + + return create_transition( + observation=observation_keys if observation_keys else None, + action=batch.get("action"), + reward=batch.get("next.reward", 0.0), + done=batch.get("next.done", False), + truncated=batch.get("next.truncated", False), + info=batch.get("info", {}), + complementary_data=complementary_data if complementary_data else None, + ) + + +def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: + """ + Convert an `EnvTransition` back to the canonical batch format used in LeRobot. + + This is the inverse of `batch_to_transition`. + + Args: + transition: The `EnvTransition` to convert. + + Returns: + A batch dictionary with canonical LeRobot field names. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add complementary data. + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if comp_data: + batch.update(comp_data) + + # Flatten observation dictionary. + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch + + +def identity_transition(transition: EnvTransition) -> EnvTransition: + """ + An identity function for transitions, returning the input unchanged. + + Useful as a default or placeholder in processing pipelines. + + Args: + tr: An `EnvTransition`. + + Returns: + The same `EnvTransition`. + """ + return transition diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py new file mode 100644 index 00000000..679ba8c5 --- /dev/null +++ b/src/lerobot/processor/core.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from enum import Enum +from typing import Any, TypeAlias, TypedDict + +import numpy as np +import torch + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +PolicyAction: TypeAlias = torch.Tensor +RobotAction: TypeAlias = dict[str, Any] +EnvAction: TypeAlias = np.ndarray +RobotObservation: TypeAlias = dict[str, Any] + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py new file mode 100644 index 00000000..949ae78d --- /dev/null +++ b/src/lerobot/processor/delta_action_processor.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature + +from .core import PolicyAction, RobotAction +from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep + + +@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict") +@dataclass +class MapTensorToDeltaActionDictStep(ActionProcessorStep): + """ + Maps a flat action tensor from a policy to a structured delta action dictionary. + + This step is typically used after a policy outputs a continuous action vector. + It decomposes the vector into named components for delta movements of the + end-effector (x, y, z) and optionally the gripper. + + Attributes: + use_gripper: If True, assumes the 4th element of the tensor is the + gripper action. + """ + + use_gripper: bool = True + + def action(self, action: PolicyAction) -> RobotAction: + if not isinstance(action, PolicyAction): + raise ValueError("Only PolicyAction is supported for this processor") + + if action.dim() > 1: + action = action.squeeze(0) + + # TODO (maractingi): add rotation + delta_action = { + "delta_x": action[0].item(), + "delta_y": action[1].item(), + "delta_z": action[2].item(), + } + if self.use_gripper: + delta_action["gripper"] = action[3].item() + return delta_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z"]: + features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + if self.use_gripper: + features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register("map_delta_action_to_robot_action") +@dataclass +class MapDeltaActionToRobotActionStep(RobotActionProcessorStep): + """ + Maps delta actions from teleoperators to robot target actions for inverse kinematics. + + This step converts a dictionary of delta movements (e.g., from a gamepad) + into a target action format that includes an "enabled" flag and target + end-effector positions. It also handles scaling and noise filtering. + + Attributes: + position_scale: A factor to scale the delta position inputs. + rotation_scale: A factor to scale the delta rotation inputs (currently unused). + noise_threshold: The magnitude below which delta inputs are considered noise + and do not trigger an "enabled" state. + """ + + # Scale factors for delta movements + position_scale: float = 1.0 + rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard + noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise + + def action(self, action: RobotAction) -> RobotAction: + # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy + # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices + delta_x = action.pop("delta_x") + delta_y = action.pop("delta_y") + delta_z = action.pop("delta_z") + gripper = action.pop("gripper") + + # Determine if the teleoperator is actively providing input + # Consider enabled if any significant movement delta is detected + position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position + enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise + + # Scale the deltas appropriately + scaled_delta_x = delta_x * self.position_scale + scaled_delta_y = delta_y * self.position_scale + scaled_delta_z = delta_z * self.position_scale + + # For gamepad/keyboard, we don't have rotation input, so set to 0 + # These could be extended in the future for more sophisticated teleoperators + target_wx = 0.0 + target_wy = 0.0 + target_wz = 0.0 + + # Update action with robot target format + action = { + "enabled": enabled, + "target_x": scaled_delta_x, + "target_y": scaled_delta_y, + "target_z": scaled_delta_z, + "target_wx": target_wx, + "target_wy": target_wy, + "target_wz": target_wz, + "gripper_vel": float(gripper), + } + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z", "gripper"]: + features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None) + + for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]: + features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 0f00bb47..2d0dd088 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -13,70 +13,182 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This script defines a processor step for moving environment transition data to a specific torch device and casting +its floating-point precision. +""" + from dataclasses import dataclass from typing import Any import torch -from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.utils.utils import get_safe_torch_device +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import ProcessorStep, ProcessorStepRegistry + +@ProcessorStepRegistry.register("device_processor") @dataclass -class DeviceProcessor: - """Processes transitions by moving tensors to the specified device. +class DeviceProcessorStep(ProcessorStep): + """ + Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their + floating-point data type. - This processor ensures that all tensors in the transition are moved to the - specified device (CPU or GPU) before they are returned. + This is crucial for preparing data for model training or inference on hardware like GPUs. + + Attributes: + device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0"). + float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16"). + If None, the dtype is not changed. """ - device: torch.device = "cpu" + device: str = "cpu" + float_dtype: str | None = None + + DTYPE_MAPPING = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float": torch.float32, + "double": torch.float64, + } def __post_init__(self): - self.device = get_safe_torch_device(self.device) + """ + Initializes the processor by converting string configurations to torch objects. + + This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the + `float_dtype` string, converting it to a `torch.dtype` object. + """ + self.tensor_device: torch.device = get_safe_torch_device(self.device) + # Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0") + self.device = self.tensor_device.type self.non_blocking = "cuda" in str(self.device) + # Validate and convert float_dtype string to torch dtype + if self.float_dtype is not None: + if self.float_dtype not in self.DTYPE_MAPPING: + raise ValueError( + f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}" + ) + self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype] + else: + self._target_float_dtype = None + + def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Moves a single tensor to the target device and casts its dtype. + + Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than + the target, which is useful when using frameworks like Accelerate. + + Args: + tensor: The input torch.Tensor. + + Returns: + The processed tensor on the correct device and with the correct dtype. + """ + # Determine target device + if tensor.is_cuda and self.tensor_device.type == "cuda": + # Both tensor and target are on GPU - preserve tensor's GPU placement. + # This handles multi-GPU scenarios where Accelerate has already placed + # tensors on the correct GPU for each process. + target_device = tensor.device + else: + # Either tensor is on CPU, or we're configured for CPU. + # In both cases, use the configured device. + target_device = self.tensor_device + + # MPS workaround: Convert float64 to float32 since MPS doesn't support float64 + if target_device.type == "mps" and tensor.dtype == torch.float64: + tensor = tensor.to(dtype=torch.float32) + + # Only move if necessary + if tensor.device != target_device: + tensor = tensor.to(target_device, non_blocking=self.non_blocking) + + # Convert float dtype if specified and tensor is floating point + if self._target_float_dtype is not None and tensor.is_floating_point(): + tensor = tensor.to(dtype=self._target_float_dtype) + + return tensor + def __call__(self, transition: EnvTransition) -> EnvTransition: - # Create a copy of the transition + """ + Applies device and dtype conversion to all tensors in an environment transition. + + It iterates through the transition, finds all `torch.Tensor` objects (including those nested in + dictionaries like `observation`), and processes them. + + Args: + transition: The input `EnvTransition` object. + + Returns: + A new `EnvTransition` object with all tensors moved to the target device and dtype. + """ new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) - # Process observation tensors - observation = transition.get(TransitionKey.OBSERVATION) - if observation is not None: - new_observation = { - k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v - for k, v in observation.items() - } - new_transition[TransitionKey.OBSERVATION] = new_observation + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}") - # Process action tensor - action = transition.get(TransitionKey.ACTION) - if action is not None and isinstance(action, torch.Tensor): - new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + simple_tensor_keys = [ + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + ] - # Process reward tensor - reward = transition.get(TransitionKey.REWARD) - if reward is not None and isinstance(reward, torch.Tensor): - new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + dict_tensor_keys = [ + TransitionKey.OBSERVATION, + TransitionKey.COMPLEMENTARY_DATA, + ] - # Process done tensor - done = transition.get(TransitionKey.DONE) - if done is not None and isinstance(done, torch.Tensor): - new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + # Process simple, top-level tensors + for key in simple_tensor_keys: + value = transition.get(key) + if isinstance(value, torch.Tensor): + new_transition[key] = self._process_tensor(value) - # Process truncated tensor - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is not None and isinstance(truncated, torch.Tensor): - new_transition[TransitionKey.TRUNCATED] = truncated.to( - self.device, non_blocking=self.non_blocking - ) + # Process tensors nested within dictionaries + for key in dict_tensor_keys: + data_dict = transition.get(key) + if data_dict is not None: + new_data_dict = { + k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in data_dict.items() + } + new_transition[key] = new_data_dict return new_transition def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {"device": self.device} + """ + Returns the serializable configuration of the processor. - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + Returns: + A dictionary containing the device and float_dtype settings. + """ + return {"device": self.device, "float_dtype": self.float_dtype} + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape). + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py new file mode 100644 index 00000000..5a0c4107 --- /dev/null +++ b/src/lerobot/processor/factory.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from .core import RobotAction, RobotObservation +from .pipeline import IdentityProcessorStep, RobotProcessorPipeline + + +def make_default_teleop_action_processor() -> RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction +]: + teleop_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + return teleop_action_processor + + +def make_default_robot_action_processor() -> RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction +]: + robot_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + return robot_action_processor + + +def make_default_robot_observation_processor() -> RobotProcessorPipeline[RobotObservation, RobotObservation]: + robot_observation_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[IdentityProcessorStep()], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + return robot_observation_processor + + +def make_default_processors(): + teleop_action_processor = make_default_teleop_action_processor() + robot_action_processor = make_default_robot_action_processor() + robot_observation_processor = make_default_robot_observation_processor() + return (teleop_action_processor, robot_action_processor, robot_observation_processor) diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py new file mode 100644 index 00000000..8fa8cfd8 --- /dev/null +++ b/src/lerobot/processor/gym_action_processor.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature + +from .converters import to_tensor +from .core import EnvAction, EnvTransition, PolicyAction +from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry + + +@ProcessorStepRegistry.register("torch2numpy_action_processor") +@dataclass +class Torch2NumpyActionProcessorStep(ActionProcessorStep): + """ + Converts a PyTorch tensor action to a NumPy array. + + This step is useful when the output of a policy (typically a torch.Tensor) + needs to be passed to an environment or component that expects a NumPy array. + + Attributes: + squeeze_batch_dim: If True, removes the first dimension of the array + if it is of size 1. This is useful for converting a + batched action of size (1, D) to a single action of size (D,). + """ + + squeeze_batch_dim: bool = True + + def action(self, action: PolicyAction) -> EnvAction: + if not isinstance(action, PolicyAction): + raise TypeError( + f"Expected PolicyAction or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." + ) + + numpy_action = action.detach().cpu().numpy() + + # Remove batch dimensions but preserve action dimensions. + # Only squeeze if there's a batch dimension (first dim == 1). + if ( + self.squeeze_batch_dim + and numpy_action.shape + and len(numpy_action.shape) > 1 + and numpy_action.shape[0] == 1 + ): + numpy_action = numpy_action.squeeze(0) + + return numpy_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("numpy2torch_action_processor") +@dataclass +class Numpy2TorchActionProcessorStep(ProcessorStep): + """Converts a NumPy array action to a PyTorch tensor when action is present.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Converts numpy action to torch tensor if action exists, otherwise passes through.""" + from .core import TransitionKey + + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + if not isinstance(action, EnvAction): + raise TypeError( + f"Expected np.ndarray or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." + ) + torch_action = to_tensor(action, dtype=None) # Preserve original dtype + new_transition[TransitionKey.ACTION] = torch_action + + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py new file mode 100644 index 00000000..47f69a97 --- /dev/null +++ b/src/lerobot/processor/hil_processor.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +from dataclasses import dataclass +from typing import Any, Protocol, TypeVar, runtime_checkable + +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.teleoperators.utils import TeleopEvents + +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import ( + ComplementaryDataProcessorStep, + InfoProcessorStep, + ObservationProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + TruncatedProcessorStep, +) + +GRIPPER_KEY = "gripper" +DISCRETE_PENALTY_KEY = "discrete_penalty" +TELEOP_ACTION_KEY = "teleop_action" + + +@runtime_checkable +class HasTeleopEvents(Protocol): + """ + Minimal protocol for objects that provide teleoperation events. + + This protocol defines the `get_teleop_events()` method, allowing processor + steps to interact with teleoperators that support event-based controls + (like episode termination or success flagging) without needing to know the + teleoperator's specific class. + """ + + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the teleoperator. + + Returns: + A dictionary containing control events such as: + - `is_intervention`: bool - Whether the human is currently intervening. + - `terminate_episode`: bool - Whether to terminate the current episode. + - `success`: bool - Whether the episode was successful. + - `rerecord_episode`: bool - Whether to rerecord the episode. + """ + ... + + +# Type variable constrained to Teleoperator subclasses that also implement events +TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) + + +def _check_teleop_with_events(teleop: Teleoperator) -> None: + """ + Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. + + Args: + teleop: The teleoperator instance to check. + + Raises: + TypeError: If the teleoperator does not have a `get_teleop_events` method. + """ + if not isinstance(teleop, HasTeleopEvents): + raise TypeError( + f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. " + f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop" + ) + + +@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data") +@dataclass +class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): + """ + Adds the raw action from a teleoperator to the transition's complementary data. + + This is useful for human-in-the-loop scenarios where the human's input needs to + be available to downstream processors, for example, to override a policy's action + during an intervention. + + Attributes: + teleop_device: The teleoperator instance to get the action from. + """ + + teleop_device: Teleoperator + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Retrieves the teleoperator's action and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data dictionary. + + Returns: + A new dictionary with the teleoperator action added under the + `teleop_action` key. + """ + new_complementary_data = dict(complementary_data) + new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("add_teleop_action_as_info") +@dataclass +class AddTeleopEventsAsInfoStep(InfoProcessorStep): + """ + Adds teleoperator control events (e.g., terminate, success) to the transition's info. + + This step extracts control events from teleoperators that support event-based + interaction, making these signals available to other parts of the system. + + Attributes: + teleop_device: An instance of a teleoperator that implements the + `HasTeleopEvents` protocol. + """ + + teleop_device: TeleopWithEvents + + def __post_init__(self): + """Validates that the provided teleoperator supports events after initialization.""" + _check_teleop_with_events(self.teleop_device) + + def info(self, info: dict) -> dict: + """ + Retrieves teleoperator events and updates the info dictionary. + + Args: + info: The incoming info dictionary. + + Returns: + A new dictionary including the teleoperator events. + """ + new_info = dict(info) + + teleop_events = self.teleop_device.get_teleop_events() + new_info.update(teleop_events) + return new_info + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("image_crop_resize_processor") +@dataclass +class ImageCropResizeProcessorStep(ObservationProcessorStep): + """ + Crops and/or resizes image observations. + + This step iterates through all image keys in an observation dictionary and applies + the specified transformations. It handles device placement, moving tensors to the + CPU if necessary for operations not supported on certain accelerators like MPS. + + Attributes: + crop_params_dict: A dictionary mapping image keys to cropping parameters + (top, left, height, width). + resize_size: A tuple (height, width) to resize all images to. + """ + + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None + + def observation(self, observation: dict) -> dict: + """ + Applies cropping and resizing to all images in the observation dictionary. + + Args: + observation: The observation dictionary, potentially containing image tensors. + + Returns: + A new observation dictionary with transformed images. + """ + if self.resize_size is None and not self.crop_params_dict: + return observation + + new_observation = dict(observation) + + # Process all image keys in the observation + for key in observation: + if "image" not in key: + continue + + image = observation[key] + device = image.device + # NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu + if device.type == "mps": + image = image.cpu() + # Crop if crop params are provided for this key + if self.crop_params_dict is not None and key in self.crop_params_dict: + crop_params = self.crop_params_dict[key] + image = F.crop(image, *crop_params) + if self.resize_size is not None: + image = F.resize(image, self.resize_size) + image = image.clamp(0.0, 1.0) + new_observation[key] = image.to(device) + + return new_observation + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary with the crop parameters and resize dimensions. + """ + return { + "crop_params_dict": self.crop_params_dict, + "resize_size": self.resize_size, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the image feature shapes in the policy features dictionary if resizing is applied. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary with new image shapes. + """ + if self.resize_size is None: + return features + for key in features[PipelineFeatureType.OBSERVATION]: + if "image" in key: + nb_channel = features[PipelineFeatureType.OBSERVATION][key].shape[0] + features[PipelineFeatureType.OBSERVATION][key] = PolicyFeature( + type=features[PipelineFeatureType.OBSERVATION][key].type, + shape=(nb_channel, *self.resize_size), + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("time_limit_processor") +class TimeLimitProcessorStep(TruncatedProcessorStep): + """ + Tracks episode steps and enforces a time limit by truncating the episode. + + Attributes: + max_episode_steps: The maximum number of steps allowed per episode. + current_step: The current step count for the active episode. + """ + + max_episode_steps: int + current_step: int = 0 + + def truncated(self, truncated: bool) -> bool: + """ + Increments the step counter and sets the truncated flag if the time limit is reached. + + Args: + truncated: The incoming truncated flag. + + Returns: + True if the episode step limit is reached, otherwise the incoming value. + """ + self.current_step += 1 + if self.current_step >= self.max_episode_steps: + truncated = True + # TODO (steven): missing an else truncated = False? + return truncated + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the `max_episode_steps`. + """ + return { + "max_episode_steps": self.max_episode_steps, + } + + def reset(self) -> None: + """Resets the step counter, typically called at the start of a new episode.""" + self.current_step = 0 + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("gripper_penalty_processor") +class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): + """ + Applies a penalty for inefficient gripper usage. + + This step penalizes actions that attempt to close an already closed gripper or + open an already open one, based on position thresholds. + + Attributes: + penalty: The negative reward value to apply. + max_gripper_pos: The maximum position value for the gripper, used for normalization. + """ + + penalty: float = -0.01 + max_gripper_pos: float = 30.0 + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Calculates the gripper penalty and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data, which should contain + raw joint positions. + + Returns: + A new complementary data dictionary with the `discrete_penalty` key added. + """ + action = self.transition.get(TransitionKey.ACTION) + + raw_joint_positions = complementary_data.get("raw_joint_positions", None) + if raw_joint_positions is None: + return complementary_data + + current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) + if current_gripper_pos is None: + return complementary_data + + # Gripper action is a PolicyAction at this stage + gripper_action = action[-1].item() + gripper_action_normalized = gripper_action / self.max_gripper_pos + + # Normalize gripper state and action + gripper_state_normalized = current_gripper_pos / self.max_gripper_pos + + # Calculate penalty boolean as in original + gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 + ) + + gripper_penalty = self.penalty * int(gripper_penalty_bool) + + # Create new complementary data with penalty info + new_complementary_data = dict(complementary_data) + new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty + + return new_complementary_data + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the penalty value and max gripper position. + """ + return { + "penalty": self.penalty, + "max_gripper_pos": self.max_gripper_pos, + } + + def reset(self) -> None: + """Resets the processor's internal state.""" + pass + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("intervention_action_processor") +class InterventionActionProcessorStep(ProcessorStep): + """ + Handles human intervention, overriding policy actions and managing episode termination. + + When an intervention is detected (via teleoperator events in the `info` dict), + this step replaces the policy's action with the human's teleoperated action. + It also processes signals to terminate the episode or flag success. + + Attributes: + use_gripper: Whether to include the gripper in the teleoperated action. + terminate_on_success: If True, automatically sets the `done` flag when a + `success` event is received. + """ + + use_gripper: bool = False + terminate_on_success: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes the transition to handle interventions. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition, potentially with an overridden action, updated + reward, and termination status. + """ + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + # Get intervention signals from complementary data + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {}) + is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) + terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) + success = info.get(TeleopEvents.SUCCESS, False) + rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False) + + new_transition = transition.copy() + + # Override action if intervention is active + if is_intervention and teleop_action is not None: + if isinstance(teleop_action, dict): + # Convert teleop_action dict to tensor format + action_list = [ + teleop_action.get("delta_x", 0.0), + teleop_action.get("delta_y", 0.0), + teleop_action.get("delta_z", 0.0), + ] + if self.use_gripper: + action_list.append(teleop_action.get(GRIPPER_KEY, 1.0)) + elif isinstance(teleop_action, np.ndarray): + action_list = teleop_action.tolist() + else: + action_list = teleop_action + + teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device) + new_transition[TransitionKey.ACTION] = teleop_action_tensor + + # Handle episode termination + new_transition[TransitionKey.DONE] = bool(terminate_episode) or ( + self.terminate_on_success and success + ) + new_transition[TransitionKey.REWARD] = float(success) + + # Update info with intervention metadata + info = new_transition.get(TransitionKey.INFO, {}) + info[TeleopEvents.IS_INTERVENTION] = is_intervention + info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode + info[TeleopEvents.SUCCESS] = success + new_transition[TransitionKey.INFO] = info + + # Update complementary data with teleop action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return new_transition + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ + return { + "use_gripper": self.use_gripper, + "terminate_on_success": self.terminate_on_success, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("reward_classifier_processor") +class RewardClassifierProcessorStep(ProcessorStep): + """ + Applies a pretrained reward classifier to image observations to predict success. + + This step uses a model to determine if the current state is successful, updating + the reward and potentially terminating the episode. + + Attributes: + pretrained_path: Path to the pretrained reward classifier model. + device: The device to run the classifier on. + success_threshold: The probability threshold to consider a prediction as successful. + success_reward: The reward value to assign on success. + terminate_on_success: If True, terminates the episode upon successful classification. + reward_classifier: The loaded classifier model instance. + """ + + pretrained_path: str | None = None + device: str = "cpu" + success_threshold: float = 0.5 + success_reward: float = 1.0 + terminate_on_success: bool = True + + reward_classifier: Any = None + + def __post_init__(self): + """Initializes the reward classifier model after the dataclass is created.""" + if self.pretrained_path is not None: + from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + + self.reward_classifier = Classifier.from_pretrained(self.pretrained_path) + self.reward_classifier.to(self.device) + self.reward_classifier.eval() + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes a transition, applying the reward classifier to its image observations. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition with an updated reward and done flag based on the + classifier's prediction. + """ + new_transition = transition.copy() + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is None or self.reward_classifier is None: + return new_transition + + # Extract images from observation + images = {key: value for key, value in observation.items() if "image" in key} + + if not images: + return new_transition + + # Run reward classifier + start_time = time.perf_counter() + with torch.inference_mode(): + success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold) + + classifier_frequency = 1 / (time.perf_counter() - start_time) + + # Calculate reward and termination + reward = new_transition.get(TransitionKey.REWARD, 0.0) + terminated = new_transition.get(TransitionKey.DONE, False) + + if math.isclose(success, 1, abs_tol=1e-2): + reward = self.success_reward + if self.terminate_on_success: + terminated = True + + # Update transition + new_transition[TransitionKey.REWARD] = reward + new_transition[TransitionKey.DONE] = terminated + + # Update info with classifier frequency + info = new_transition.get(TransitionKey.INFO, {}) + info["reward_classifier_frequency"] = classifier_frequency + new_transition[TransitionKey.INFO] = info + + return new_transition + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ + return { + "device": self.device, + "success_threshold": self.success_threshold, + "success_reward": self.success_reward, + "terminate_on_success": self.terminate_on_success, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py new file mode 100644 index 00000000..ab3c6ecc --- /dev/null +++ b/src/lerobot/processor/joint_observations_processor.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_STATE +from lerobot.processor.pipeline import ( + ObservationProcessorStep, + ProcessorStepRegistry, +) +from lerobot.robots import Robot + + +@dataclass +@ProcessorStepRegistry.register("joint_velocity_processor") +class JointVelocityProcessorStep(ObservationProcessorStep): + """ + Calculates and appends joint velocity information to the observation state. + + This step computes the velocity of each joint by calculating the finite + difference between the current and the last observed joint positions. The + resulting velocity vector is then concatenated to the original state vector. + + Attributes: + dt: The time step (delta time) in seconds between observations, used for + calculating velocity. + last_joint_positions: Stores the joint positions from the previous step + to enable velocity calculation. + """ + + dt: float = 0.1 + + last_joint_positions: torch.Tensor | None = None + + def observation(self, observation: dict) -> dict: + """ + Computes joint velocities and adds them to the observation state. + + Args: + observation: The input observation dictionary, expected to contain + an `observation.state` key with joint positions. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include joint velocities. + + Raises: + ValueError: If `observation.state` is not found in the observation. + """ + # Get current joint positions (assuming they're in observation.state) + current_positions = observation.get(OBS_STATE) + if current_positions is None: + raise ValueError(f"{OBS_STATE} is not in observation") + + # Initialize last joint positions if not already set + if self.last_joint_positions is None: + self.last_joint_positions = current_positions.clone() + joint_velocities = torch.zeros_like(current_positions) + else: + # Compute velocities + joint_velocities = (current_positions - self.last_joint_positions) / self.dt + + self.last_joint_positions = current_positions.clone() + + # Extend observation with velocities + extended_state = torch.cat([current_positions, joint_velocities], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + new_observation[OBS_STATE] = extended_state + + return new_observation + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the time step `dt`. + """ + return { + "dt": self.dt, + } + + def reset(self) -> None: + """Resets the internal state, clearing the last known joint positions.""" + self.last_joint_positions = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the `observation.state` feature to reflect the added velocities. + + This method doubles the size of the first dimension of the `observation.state` + shape to account for the concatenation of position and velocity vectors. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ + if OBS_STATE in features[PipelineFeatureType.OBSERVATION]: + original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE] + # Double the shape to account for positions + velocities + new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:] + + features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature( + type=original_feature.type, shape=new_shape + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("current_processor") +class MotorCurrentProcessorStep(ObservationProcessorStep): + """ + Reads motor currents from a robot and appends them to the observation state. + + This step queries the robot's hardware interface to get the present current + for each motor and concatenates this information to the existing state vector. + + Attributes: + robot: An instance of a `lerobot` Robot class that provides access to + the hardware bus. + """ + + robot: Robot | None = None + + def observation(self, observation: dict) -> dict: + """ + Fetches motor currents and adds them to the observation state. + + Args: + observation: The input observation dictionary. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include motor currents. + + Raises: + ValueError: If the `robot` attribute has not been set. + """ + # Get current values from robot state + if self.robot is None: + raise ValueError("Robot is not set") + + present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined] + motor_currents = torch.tensor( + [present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined] + dtype=torch.float32, + ).unsqueeze(0) + + current_state = observation.get(OBS_STATE) + if current_state is None: + return observation + + extended_state = torch.cat([current_state, motor_currents], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + new_observation[OBS_STATE] = extended_state + + return new_observation + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the `observation.state` feature to reflect the added motor currents. + + This method increases the size of the first dimension of the `observation.state` + shape by the number of motors in the robot. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ + if OBS_STATE in features[PipelineFeatureType.OBSERVATION] and self.robot is not None: + original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE] + # Add motor current dimensions to the original state shape + num_motors = 0 + if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined] + num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined] + + if num_motors > 0: + new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:] + features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature( + type=original_feature.type, shape=new_shape + ) + return features diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py new file mode 100644 index 00000000..131f799d --- /dev/null +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A generic script to migrate LeRobot policies with built-in normalization layers to the new +pipeline-based processor system. + +This script performs the following steps: +1. Loads a pretrained policy model and its configuration from a local path or the + Hugging Face Hub. +2. Scans the model's state dictionary to extract normalization statistics (e.g., mean, + std, min, max) for all features. +3. Creates two new processor pipelines: + - A preprocessor that normalizes inputs (observations) and outputs (actions). + - A postprocessor that unnormalizes outputs (actions) for inference. +4. Removes the original normalization layers from the model's state dictionary, + creating a "clean" model. +5. Saves the new clean model, the preprocessor, the postprocessor, and a generated + model card to a new directory. +6. Optionally pushes all the new artifacts to the Hugging Face Hub. + +Usage: + python src/lerobot/processor/migrate_policy_normalization.py \ + --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ + --push-to-hub \ + --branch main + +Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config` +factory functions from `lerobot.policies.factory` to create processors and configurations, +ensuring consistency with the current codebase. + +The script extracts normalization statistics from the old model's state_dict, creates clean +processor pipelines using the factory functions, and saves a migrated model that is compatible +with the new PolicyProcessorPipeline architecture. +""" + +import argparse +import json +import os +from pathlib import Path +from typing import Any + +import torch +from huggingface_hub import HfApi, hf_hub_download +from safetensors.torch import load_file as load_safetensors + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors + + +def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: + """ + Scans a model's state_dict to find and extract normalization statistics. + + This function identifies keys corresponding to normalization layers (e.g., those + for mean, std, min, max) based on a set of predefined patterns and organizes + them into a nested dictionary. + + Args: + state_dict: The state dictionary of a pretrained policy model. + + Returns: + A nested dictionary where outer keys are feature names (e.g., + 'observation.state') and inner keys are statistic types ('mean', 'std'), + mapping to their corresponding tensor values. + """ + stats = {} + + # Define patterns to match and their prefixes to remove + normalization_patterns = [ + "normalize_inputs.buffer_", + "unnormalize_outputs.buffer_", + "normalize_targets.buffer_", + "normalize.", # Must come after normalize_* patterns + "unnormalize.", # Must come after unnormalize_* patterns + "input_normalizer.", + "output_normalizer.", + "normalalize_inputs.", + "unnormalize_outputs.", + "normalize_targets.", + "unnormalize_targets.", + ] + + # Process each key in state_dict + for key, tensor in state_dict.items(): + # Try each pattern + for pattern in normalization_patterns: + if key.startswith(pattern): + # Extract the remaining part after the pattern + remaining = key[len(pattern) :] + parts = remaining.split(".") + + # Need at least feature name and stat type + if len(parts) >= 2: + # Last part is the stat type (mean, std, min, max, etc.) + stat_type = parts[-1] + # Everything else is the feature name + feature_name = ".".join(parts[:-1]).replace("_", ".") + + # Add to stats + if feature_name not in stats: + stats[feature_name] = {} + stats[feature_name][stat_type] = tensor.clone() + + # Only process the first matching pattern + break + + return stats + + +def detect_features_and_norm_modes( + config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]] +) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]: + """ + Infers policy features and normalization modes from the model config and stats. + + This function first attempts to find feature definitions and normalization + mappings directly from the policy's configuration file. If this information is + not present, it infers it from the extracted normalization statistics, using + tensor shapes to determine feature shapes and the presence of specific stat + keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode. + It applies sensible defaults if inference is not possible. + + Args: + config: The policy's configuration dictionary from `config.json`. + stats: The normalization statistics extracted from the model's state_dict. + + Returns: + A tuple containing: + - A dictionary mapping feature names to `PolicyFeature` objects. + - A dictionary mapping `FeatureType` enums to `NormalizationMode` enums. + """ + features = {} + norm_modes = {} + + # First, check if there's a normalization_mapping in the config + if "normalization_mapping" in config: + print(f"Found normalization_mapping in config: {config['normalization_mapping']}") + # Extract normalization modes from config + for feature_type_str, mode_str in config["normalization_mapping"].items(): + # Convert string to FeatureType enum + try: + if feature_type_str == "VISUAL": + feature_type = FeatureType.VISUAL + elif feature_type_str == "STATE": + feature_type = FeatureType.STATE + elif feature_type_str == "ACTION": + feature_type = FeatureType.ACTION + else: + print(f"Warning: Unknown feature type '{feature_type_str}', skipping") + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse feature type '{feature_type_str}', skipping") + continue + + # Convert string to NormalizationMode enum + try: + if mode_str == "MEAN_STD": + mode = NormalizationMode.MEAN_STD + elif mode_str == "MIN_MAX": + mode = NormalizationMode.MIN_MAX + elif mode_str == "IDENTITY": + mode = NormalizationMode.IDENTITY + else: + print( + f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'" + ) + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse normalization mode '{mode_str}', skipping") + continue + + norm_modes[feature_type] = mode + + # Try to extract from config + if "features" in config: + for key, feature_config in config["features"].items(): + shape = feature_config.get("shape", feature_config.get("dim")) + shape = (shape,) if isinstance(shape, int) else tuple(shape) + + # Determine feature type + if "image" in key or "visual" in key: + feature_type = FeatureType.VISUAL + elif "state" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE # Default + + features[key] = PolicyFeature(feature_type, shape) + + # If no features in config, infer from stats + if not features: + for key, stat_dict in stats.items(): + # Get shape from any stat tensor + tensor = next(iter(stat_dict.values())) + shape = tuple(tensor.shape) + + # Determine feature type based on key + if "image" in key or "visual" in key or "pixels" in key: + feature_type = FeatureType.VISUAL + elif "state" in key or "joint" in key or "position" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE + + features[key] = PolicyFeature(feature_type, shape) + + # If normalization modes weren't in config, determine based on available stats + if not norm_modes: + for key, stat_dict in stats.items(): + if key in features: + if "mean" in stat_dict and "std" in stat_dict: + feature_type = features[key].type + if feature_type not in norm_modes: + norm_modes[feature_type] = NormalizationMode.MEAN_STD + elif "min" in stat_dict and "max" in stat_dict: + feature_type = features[key].type + if feature_type not in norm_modes: + norm_modes[feature_type] = NormalizationMode.MIN_MAX + + # Default normalization modes if not detected + if FeatureType.VISUAL not in norm_modes: + norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD + if FeatureType.STATE not in norm_modes: + norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX + if FeatureType.ACTION not in norm_modes: + norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD + + return features, norm_modes + + +def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Creates a new state_dict with all normalization-related layers removed. + + This function filters the original state dictionary, excluding any keys that + match a set of predefined patterns associated with normalization modules. + + Args: + state_dict: The original model state dictionary. + + Returns: + A new state dictionary containing only the core model weights, without + any normalization parameters. + """ + new_state_dict = {} + + # Patterns to remove + remove_patterns = [ + "normalize_inputs.", + "unnormalize_outputs.", + "normalize_targets.", # Added pattern for target normalization + "normalize.", + "unnormalize.", + "input_normalizer.", + "output_normalizer.", + "normalizer.", + ] + + for key, tensor in state_dict.items(): + should_remove = any(pattern in key for pattern in remove_patterns) + if not should_remove: + new_state_dict[key] = tensor + + return new_state_dict + + +def clean_state_dict( + state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod" +) -> dict[str, torch.Tensor]: + """ + Remove a substring (e.g. '._orig_mod') from all keys in a state dict. + + Args: + state_dict (dict): The original state dict. + remove_str (str): The substring to remove from the keys. + + Returns: + dict: A new state dict with cleaned keys. + """ + new_state_dict = {} + for k, v in state_dict.items(): + new_k = k.replace(remove_str, "") + new_state_dict[new_k] = v + return new_state_dict + + +def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: + """ + Converts a feature dictionary from the old config format to the new `PolicyFeature` format. + + Args: + features_dict: The feature dictionary in the old format, where values are + simple dictionaries (e.g., `{"shape": [7]}`). + + Returns: + A dictionary mapping feature names to `PolicyFeature` dataclass objects. + """ + converted_features = {} + + for key, feature_dict in features_dict.items(): + # Determine feature type based on key + if "image" in key or "visual" in key: + feature_type = FeatureType.VISUAL + elif "state" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE + + # Get shape from feature dict + shape = feature_dict.get("shape", feature_dict.get("dim")) + shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else () + + converted_features[key] = PolicyFeature(feature_type, shape) + + return converted_features + + +def load_model_from_hub( + repo_id: str, revision: str | None = None +) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: + """ + Downloads and loads a model's state_dict and configs from the Hugging Face Hub. + + Args: + repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha'). + revision: The specific git revision (branch, tag, or commit hash) to use. + + Returns: + A tuple containing the model's state dictionary, the policy configuration, + and the training configuration. + """ + # Download files. + safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) + + config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) + train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) + + # Load state_dict + state_dict = load_safetensors(safetensors_path) + + # Load config + with open(config_path) as f: + config = json.load(f) + + with open(train_config_path) as f: + train_config = json.load(f) + + return state_dict, config, train_config + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate policy models with normalization layers to new pipeline system" + ) + parser.add_argument( + "--pretrained-path", + type=str, + required=True, + help="Path to pretrained model (hub repo or local directory)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory for migrated model (default: same as pretrained-path)", + ) + parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub") + parser.add_argument( + "--hub-repo-id", + type=str, + default=None, + help="Hub repository ID for pushing (default: same as pretrained-path)", + ) + parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load") + parser.add_argument("--private", action="store_true", help="Make the hub repository private") + parser.add_argument( + "--branch", + type=str, + default=None, + help="Git branch to use when pushing to hub. If specified, a PR will be created automatically (default: push directly to main)", + ) + + args = parser.parse_args() + + # Load model and config + print(f"Loading model from {args.pretrained_path}...") + if os.path.isdir(args.pretrained_path): + # Local directory + state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors")) + with open(os.path.join(args.pretrained_path, "config.json")) as f: + config = json.load(f) + with open(os.path.join(args.pretrained_path, "train_config.json")) as f: + train_config = json.load(f) + else: + # Hub repository + state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision) + + # Extract normalization statistics + print("Extracting normalization statistics...") + stats = extract_normalization_stats(state_dict) + + print(f"Found normalization statistics for: {list(stats.keys())}") + + # Detect input features and normalization modes + print("Detecting features and normalization modes...") + features, norm_map = detect_features_and_norm_modes(config, stats) + + print(f"Detected features: {list(features.keys())}") + print(f"Normalization modes: {norm_map}") + + # Remove normalization layers from state_dict + print("Removing normalization layers from model...") + new_state_dict = remove_normalization_layers(state_dict) + new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod") + + removed_keys = set(state_dict.keys()) - set(new_state_dict.keys()) + if removed_keys: + print(f"Removed {len(removed_keys)} normalization layer keys") + + # Determine output path + if args.output_dir: + output_dir = Path(args.output_dir) + else: + if os.path.isdir(args.pretrained_path): + output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated" + else: + output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated") + + output_dir.mkdir(parents=True, exist_ok=True) + + # Extract policy type from config + if "type" not in config: + raise ValueError("Policy type not found in config.json. The config must contain a 'type' field.") + + policy_type = config["type"] + print(f"Detected policy type: {policy_type}") + + # Clean up config - remove fields that shouldn't be passed to config constructor + cleaned_config = dict(config) + + # Remove fields that are not part of the config class constructors + fields_to_remove = ["normalization_mapping", "type"] + for field in fields_to_remove: + if field in cleaned_config: + print(f"Removing '{field}' field from config") + del cleaned_config[field] + + # Convert input_features and output_features to PolicyFeature objects if they exist + if "input_features" in cleaned_config: + cleaned_config["input_features"] = convert_features_to_policy_features( + cleaned_config["input_features"] + ) + if "output_features" in cleaned_config: + cleaned_config["output_features"] = convert_features_to_policy_features( + cleaned_config["output_features"] + ) + + # Add normalization mapping to config + cleaned_config["normalization_mapping"] = norm_map + + # Create policy configuration using the factory + print(f"Creating {policy_type} policy configuration...") + policy_config = make_policy_config(policy_type, **cleaned_config) + + # Create policy instance using the factory + print(f"Instantiating {policy_type} policy...") + policy_class = get_policy_class(policy_type) + policy = policy_class(policy_config) + + # Load the cleaned state dict + policy.load_state_dict(new_state_dict, strict=True) + print("Successfully loaded cleaned state dict into policy model") + + # Create preprocessor and postprocessor using the factory + print("Creating preprocessor and postprocessor using make_pre_post_processors...") + preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats) + + # Determine hub repo ID if pushing to hub + hub_repo_id = None + if args.push_to_hub: + if args.hub_repo_id: + hub_repo_id = args.hub_repo_id + else: + if not os.path.isdir(args.pretrained_path): + # Use same repo with "_migrated" suffix + hub_repo_id = f"{args.pretrained_path}_migrated" + else: + raise ValueError("--hub-repo-id must be specified when pushing local model to hub") + + # Save all components to local directory first + print(f"Saving preprocessor to {output_dir}...") + preprocessor.save_pretrained(output_dir) + + print(f"Saving postprocessor to {output_dir}...") + postprocessor.save_pretrained(output_dir) + + print(f"Saving model to {output_dir}...") + policy.save_pretrained(output_dir) + + # Generate and save model card + print("Generating model card...") + # Get metadata from original config + dataset_repo_id = train_config.get("repo_id", "unknown") + license = config.get("license", "apache-2.0") + + tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] + tags = set(tags).union({"robotics", "lerobot", policy_type}) + tags = list(tags) + + # Generate model card + card = policy.generate_model_card( + dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags + ) + + # Save model card locally + card.save(str(output_dir / "README.md")) + print(f"Model card saved to {output_dir / 'README.md'}") + # Push all files to hub in a single operation if requested + if args.push_to_hub and hub_repo_id: + api = HfApi() + + # Determine if we should create a PR (automatically if branch is specified) + create_pr = args.branch is not None + target_location = f"branch '{args.branch}'" if args.branch else "main branch" + + print(f"Pushing all migrated files to {hub_repo_id} on {target_location}...") + + # Upload all files in a single commit with automatic PR creation if branch specified + commit_message = "Migrate policy to PolicyProcessorPipeline system" + commit_description = None + + if create_pr: + # Separate commit description for PR body + commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline** + +This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture. + +## What Changed + +### ✨ **New Architecture - PolicyProcessorPipeline** +Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides: +- **Modularity**: Separate preprocessing and postprocessing pipelines +- **Flexibility**: Easy to swap, configure, and debug processing steps +- **Compatibility**: Works with the latest LeRobot ecosystem + +### 🔧 **Normalization Extraction** +We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers: +- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*` +- **Statistics preserved**: Mean, std, min, max values for all features +- **Clean model**: State dict now contains only core model weights + +### 📦 **Files Added** +- **preprocessor_config.json**: Configuration for input preprocessing pipeline +- **postprocessor_config.json**: Configuration for output postprocessing pipeline +- **model.safetensors**: Clean model weights without normalization layers +- **config.json**: Updated model configuration +- **train_config.json**: Training configuration +- **README.md**: Updated model card with migration information + +### 🚀 **Benefits** +- **Backward Compatible**: Your model behavior remains identical +- **Future Ready**: Compatible with latest LeRobot features and updates +- **Debuggable**: Easy to inspect and modify processing steps +- **Portable**: Processors can be shared and reused across models + +### 💻 **Usage** +```python +# Load your migrated model +from lerobot.policies import get_policy_class +from lerobot.processor import PolicyProcessorPipeline + +# The preprocessor and postprocessor are now external +preprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="preprocessor_config.json") +postprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="postprocessor_config.json") +policy = get_policy_class("your-policy-type").from_pretrained("your-model-repo") + +# Process data through the pipeline +processed_batch = preprocessor(raw_batch) +action = policy(processed_batch) +final_action = postprocessor(action) +``` + +*Generated automatically by the LeRobot policy migration script*""" + + upload_kwargs = { + "repo_id": hub_repo_id, + "folder_path": output_dir, + "repo_type": "model", + "commit_message": commit_message, + "revision": args.branch, + "create_pr": create_pr, + "allow_patterns": ["*.json", "*.safetensors", "*.md"], + "ignore_patterns": ["*.tmp", "*.log"], + } + + # Add commit_description for PR body if creating PR + if create_pr and commit_description: + upload_kwargs["commit_description"] = commit_description + + api.upload_folder(**upload_kwargs) + + if create_pr: + print("All files pushed and pull request created successfully!") + else: + print("All files pushed to main branch successfully!") + + print("\nMigration complete!") + print(f"Migrated model saved to: {output_dir}") + if args.push_to_hub and hub_repo_id: + if args.branch: + print( + f"Successfully pushed all files to branch '{args.branch}' and created PR on https://huggingface.co/{hub_repo_id}" + ) + else: + print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}") + if args.branch: + print(f"\nView the branch at: https://huggingface.co/{hub_repo_id}/tree/{args.branch}") + print( + f"View the PR at: https://huggingface.co/{hub_repo_id}/discussions (look for the most recent PR)" + ) + else: + print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 14628727..bece54f0 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,67 +1,353 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations -from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass, field from typing import Any -import numpy as np import torch from torch import Tensor -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + +from .converters import from_tensor_to_numpy, to_tensor +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry -def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: - """Convert numpy arrays and other types to torch tensors.""" - tensor_stats: dict[str, dict[str, Tensor]] = {} - for key, sub in stats.items(): - tensor_stats[key] = {} - for stat_name, value in sub.items(): - if isinstance(value, np.ndarray): - tensor_val = torch.from_numpy(value.astype(np.float32)) - elif isinstance(value, torch.Tensor): - tensor_val = value.to(dtype=torch.float32) - elif isinstance(value, (int, float, list, tuple)): - tensor_val = torch.tensor(value, dtype=torch.float32) - else: - raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") - tensor_stats[key][stat_name] = tensor_val - return tensor_stats +@dataclass +class _NormalizationMixin: + """ + A mixin class providing core functionality for normalization and unnormalization. + + This class manages normalization statistics (`stats`), converts them to tensors for + efficient computation, handles device placement, and implements the logic for + applying normalization transformations (mean/std and min/max). It is designed to + be inherited by concrete `ProcessorStep` implementations and should not be used + directly. + + **Stats Override Preservation:** + When stats are explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved even when + `load_state_dict()` is called. This allows users to override normalization + statistics from saved models while keeping the rest of the model state intact. + + Examples: + ```python + # Common use case: Override with dataset stats + from lerobot.datasets import LeRobotDataset + + dataset = LeRobotDataset("my_dataset") + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}} + ) + # dataset.meta.stats will be used, not the stats from the saved model + + # Custom stats override + custom_stats = {"action": {"mean": [0.0], "std": [1.0]}} + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": custom_stats}} + ) + ``` + + Attributes: + features: A dictionary mapping feature names to `PolicyFeature` objects, defining + the data structure to be processed. + norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying + which normalization method to use for each type of feature. + stats: A dictionary containing the normalization statistics (e.g., mean, std, + min, max) for each feature. + device: The PyTorch device on which to store and perform tensor operations. + eps: A small epsilon value to prevent division by zero in normalization + calculations. + normalize_observation_keys: An optional set of keys to selectively apply + normalization to specific observation features. + _tensor_stats: An internal dictionary holding the normalization statistics as + PyTorch tensors. + _stats_explicitly_provided: Internal flag tracking whether stats were explicitly + provided during construction (used for override preservation). + """ + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + device: torch.device | str | None = None + dtype: torch.dtype | None = None + eps: float = 1e-8 + normalize_observation_keys: set[str] | None = None + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + _stats_explicitly_provided: bool = field(default=False, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the mixin after dataclass construction. + + This method handles the robust deserialization of `features` and `norm_map` + from JSON-compatible formats (where enums become strings and tuples become + lists) and converts the provided `stats` dictionary into a dictionary of + tensors (`_tensor_stats`) on the specified device. + """ + # Track if stats were explicitly provided (not None and not empty) + self._stats_explicitly_provided = self.stats is not None and bool(self.stats) + # Robust JSON deserialization handling (guard empty maps). + if self.features: + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed + + if self.norm_map: + # if keys are strings (JSON), rebuild enum map + if all(isinstance(k, str) for k in self.norm_map.keys()): + reconstructed = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed + + # Convert stats to tensors and move to the target device once during initialization. + self.stats = self.stats or {} + if self.dtype is None: + self.dtype = torch.float32 + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ) -> _NormalizationMixin: + """ + Moves the processor's normalization stats to the specified device. + + Args: + device: The target PyTorch device. + + Returns: + The instance of the class, allowing for method chaining. + """ + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + return self + + def state_dict(self) -> dict[str, Tensor]: + """ + Returns the normalization statistics as a flat state dictionary. + + All tensors are moved to the CPU before being returned, which is standard practice + for saving state dictionaries. + + Returns: + A flat dictionary mapping from `'feature_name.stat_name'` to the + corresponding statistics tensor on the CPU. + """ + flat: dict[str, Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat + + def load_state_dict(self, state: dict[str, Tensor]) -> None: + """ + Loads normalization statistics from a state dictionary. + + The loaded tensors are moved to the processor's configured device. + + **Stats Override Preservation:** + If stats were explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved and the state + dictionary is ignored. This allows users to override normalization statistics + while still loading the rest of the model state. + + This behavior is crucial for scenarios where users want to adapt a pretrained + model to a new dataset with different statistics without retraining the entire + model. + + Args: + state: A flat state dictionary with keys in the format + `'feature_name.stat_name'`. + + Note: + When stats are preserved due to explicit provision, only the tensor + representation is updated to ensure consistency with the current device + and dtype settings. + """ + # If stats were explicitly provided during construction, preserve them + if self._stats_explicitly_provided and self.stats is not None: + # Don't load from state_dict, keep the explicitly provided stats + # But ensure _tensor_stats is properly initialized + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + return + + # Normal behavior: load stats from state_dict + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to the processor's configured device. + self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device + ) + + # Reconstruct the original stats dict from tensor stats for compatibility with to() method + # and other functions that rely on self.stats + self.stats = {} + for key, tensor_dict in self._tensor_stats.items(): + self.stats[key] = {} + for stat_name, tensor in tensor_dict.items(): + # Convert tensor back to python/numpy format + self.stats[key][stat_name] = from_tensor_to_numpy(tensor) + + def get_config(self) -> dict[str, Any]: + """ + Returns a serializable dictionary of the processor's configuration. + + This method is used when saving the processor to disk, ensuring that its + configuration can be reconstructed later. + + Returns: + A JSON-serializable dictionary containing the configuration. + """ + config = { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + if self.normalize_observation_keys is not None: + config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) + return config + + def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + """ + Applies (un)normalization to all relevant features in an observation dictionary. + + Args: + observation: The observation dictionary to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + A new observation dictionary with the transformed tensor values. + """ + new_observation = dict(observation) + for key, feature in self.features.items(): + if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: + continue + if feature.type != FeatureType.ACTION and key in new_observation: + # Convert to tensor but preserve original dtype for adaptation logic + tensor = torch.as_tensor(new_observation[key]) + new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse) + return new_observation + + def _normalize_action(self, action: Tensor, inverse: bool) -> Tensor: + # Convert to tensor but preserve original dtype for adaptation logic + """ + Applies (un)normalization to an action tensor. + + Args: + action: The action tensor to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + The transformed action tensor. + """ + processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse) + return processed_action + + def _apply_transform( + self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False + ) -> Tensor: + """ + Core logic to apply a normalization or unnormalization transformation to a tensor. + + This method selects the appropriate normalization mode (e.g., mean/std, min/max) + based on the feature type and applies the corresponding mathematical operation. + + Args: + tensor: The input tensor to transform. + key: The feature key corresponding to the tensor. + feature_type: The `FeatureType` of the tensor. + inverse: If `True`, applies the inverse transformation (unnormalization). + + Returns: + The transformed tensor. + + Raises: + ValueError: If an unsupported normalization mode is encountered. + """ + norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) + if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: + return tensor + + if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + + # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor + if self._tensor_stats and key in self._tensor_stats: + first_stat = next(iter(self._tensor_stats[key].values())) + if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype: + self.to(device=tensor.device, dtype=tensor.dtype) + + stats = self._tensor_stats[key] + + if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + # Avoid division by zero by adding a small epsilon. + denom = std + self.eps + if inverse: + return tensor * std + mean + return (tensor - mean) / denom + + if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + denom = max_val - min_val + # When min_val == max_val, substitute the denominator with a small epsilon + # to prevent division by zero. This consistently maps an input equal to + # min_val to -1, ensuring a stable transformation. + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + # Map from [-1, 1] back to [min, max] + return (tensor + 1) / 2 * denom + min_val + # Map from [min, max] to [-1, 1] + return 2 * (tensor - min_val) / denom - 1 + + # If necessary stats are missing, return input unchanged. + return tensor @dataclass @ProcessorStepRegistry.register(name="normalizer_processor") -class NormalizerProcessor: - """Normalizes observations and actions in a single processor step. - - This processor handles normalization of both observation and action tensors - using either mean/std normalization or min/max scaling to a [-1, 1] range. - - For each tensor key in the stats dictionary, the processor will: - - Use mean/std normalization if those statistics are provided: (x - mean) / std - - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 - - The processor can be configured to normalize only specific keys by setting - the normalize_keys parameter. +class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ + A processor step that applies normalization to observations and actions in a transition. - # Features and normalisation map are mandatory to match the design of normalize.py - features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - - # Pre-computed statistics coming from dataset.meta.stats for instance. - stats: dict[str, dict[str, Any]] | None = None - - # Explicit subset of keys to normalise. If ``None`` every key (except - # "action") found in ``stats`` will be normalised. Using a ``set`` makes - # membership checks O(1). - normalize_keys: set[str] | None = None - - eps: float = 1e-8 - - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + This class uses the logic from `_NormalizationMixin` to perform forward normalization + (e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]). + It is typically used in the pre-processing pipeline before feeding data to a policy. + """ @classmethod def from_lerobot_dataset( @@ -70,158 +356,73 @@ class NormalizerProcessor: features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], *, - normalize_keys: set[str] | None = None, + normalize_observation_keys: set[str] | None = None, eps: float = 1e-8, - ) -> NormalizerProcessor: - """Factory helper that pulls statistics from a :class:`LeRobotDataset`. - - The features and norm_map parameters are mandatory to match the design - pattern used in normalize.py. + device: torch.device | str | None = None, + ) -> NormalizerProcessorStep: """ + Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`. + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + normalize_observation_keys: An optional set of observation keys to normalize. + eps: A small epsilon value for numerical stability. + device: The target device for the processor. + + Returns: + A new instance of `NormalizerProcessorStep`. + """ return cls( features=features, norm_map=norm_map, stats=dataset.meta.stats, - normalize_keys=normalize_keys, + normalize_observation_keys=normalize_observation_keys, eps=eps, + device=device, ) - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features - - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - # Convert statistics once so we avoid repeated numpy→Tensor conversions - # during runtime. - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - # Ensure *normalize_keys* is a set for fast look-ups and compare by - # value later when returning the configuration. - if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): - self.normalize_keys = set(self.normalize_keys) - - def _normalize_obs(self, observation): - if observation is None: - return None - - # Decide which keys should be normalised for this call. - if self.normalize_keys is not None: - keys_to_norm = self.normalize_keys - else: - # Use feature map to skip action keys. - keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} - - processed = dict(observation) - for key in keys_to_norm: - if key not in processed or key not in self._tensor_stats: - continue - - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = (tensor - mean) / (std + self.eps) - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - return processed - - def _normalize_action(self, action): - if action is None or "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return (tensor - mean) / (std + self.eps) - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._normalize_action(transition.get(TransitionKey.ACTION)) - - # Create a new transition with normalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action + + # Handle observation normalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation( + observation, inverse=False + ) + + # Handle action normalization. + action = new_transition.get(TransitionKey.ACTION) + + if action is None: + return new_transition + + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) + return new_transition - def get_config(self) -> dict[str, Any]: - config = { - "eps": self.eps, - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - if self.normalize_keys is not None: - # Serialise as a list for YAML / JSON friendliness - config["normalize_keys"] = sorted(self.normalize_keys) - return config - - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat - - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor - - def reset(self): - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features @dataclass @ProcessorStepRegistry.register(name="unnormalizer_processor") -class UnnormalizerProcessor: - """Inverse normalisation for observations and actions. - - Exactly mirrors :class:`NormalizerProcessor` but applies the inverse - transform. +class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ + A processor step that applies unnormalization to observations and actions. - features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - stats: dict[str, dict[str, Any]] | None = None - - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + This class inverts the normalization process, scaling data back to its original + range. It is typically used in the post-processing pipeline to convert a policy's + normalized action output into a format that can be executed by a robot or + environment. + """ @classmethod def from_lerobot_dataset( @@ -229,103 +430,72 @@ class UnnormalizerProcessor: dataset: LeRobotDataset, features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], - ) -> UnnormalizerProcessor: - return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) + *, + device: torch.device | str | None = None, + ) -> UnnormalizerProcessorStep: + """ + Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`. - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + device: The target device for the processor. - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - def _unnormalize_obs(self, observation): - if observation is None: - return None - keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] - processed = dict(observation) - for key in keys: - if key not in processed or key not in self._tensor_stats: - continue - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = tensor * std + mean - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val - return processed - - def _unnormalize_action(self, action): - if action is None or "action" not in self._tensor_stats: - return action - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return tensor * std + mean - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return (tensor + 1) / 2 * (max_val - min_val) + min_val - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + Returns: + A new instance of `UnnormalizerProcessorStep`. + """ + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device) def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) - - # Create a new transition with unnormalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action + + # Handle observation unnormalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True) + + # Handle action unnormalization. + action = new_transition.get(TransitionKey.ACTION) + + if action is None: + return new_transition + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) + return new_transition - def get_config(self) -> dict[str, Any]: - return { - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat - - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor - - def reset(self): - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features + + +def hotswap_stats( + policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]] +) -> PolicyProcessorPipeline: + """ + Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance. + + This function creates a deep copy of the provided pipeline and updates the + statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it + contains. This is useful for adapting a trained policy to a new environment or + dataset with different data distributions without having to reconstruct the entire + pipeline. + + Args: + policy_processor: The policy processor pipeline to modify. + stats: The new dictionary of normalization statistics to apply. + + Returns: + A new `PolicyProcessorPipeline` instance with the updated statistics. + """ + rp = deepcopy(policy_processor) + for step in rp.steps: + if isinstance(step, _NormalizationMixin): + step.stats = stats + # Re-initialize tensor_stats on the correct device. + step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment] + return rp diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 7d63db23..71fdbbf0 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -20,32 +20,54 @@ import numpy as np import torch from torch import Tensor -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass @ProcessorStepRegistry.register(name="observation_processor") -class VanillaObservationProcessor(ObservationProcessor): +class VanillaObservationProcessorStep(ObservationProcessorStep): """ - Processes environment observations into the LeRobot format by handling both images and states. + Processes standard Gymnasium observations into the LeRobot format. - Image processing: - - Converts channel-last (H, W, C) images to channel-first (C, H, W) - - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) - - Adds a batch dimension if missing - - Supports single images and image dictionaries + This step handles both image and state data from a typical observation dictionary, + preparing it for use in a LeRobot policy. - State processing: - - Maps 'environment_state' to observation.environment_state - - Maps 'agent_pos' to observation.state - - Converts numpy arrays to tensors - - Adds a batch dimension if missing + **Image Processing:** + - Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W), + `float32` tensors. + - Normalizes pixel values from the [0, 255] range to [0, 1]. + - Adds a batch dimension if one is not already present. + - Recognizes a single image under the key `"pixels"` and maps it to + `"observation.image"`. + - Recognizes a dictionary of images under the key `"pixels"` and maps them + to `"observation.images.{camera_name}"`. + + **State Processing:** + - Maps the `"environment_state"` key to `"observation.environment_state"`. + - Maps the `"agent_pos"` key to `"observation.state"`. + - Converts NumPy arrays to PyTorch tensors. + - Adds a batch dimension if one is not already present. """ def _process_single_image(self, img: np.ndarray) -> Tensor: - """Process a single image array.""" + """ + Processes a single NumPy image array into a channel-first, normalized tensor. + + Args: + img: A NumPy array representing the image, expected to be in channel-last + (H, W, C) format with a `uint8` dtype. + + Returns: + A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with + pixel values normalized to the [0, 1] range. + + Raises: + ValueError: If the input image does not appear to be in channel-last + format or is not of `uint8` dtype. + """ # Convert to tensor img_tensor = torch.from_numpy(img) @@ -106,19 +128,32 @@ class VanillaObservationProcessor(ObservationProcessor): def observation(self, observation): return self._process_observation(observation) - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Transforms feature keys to a standardized contract. - - This method handles several renaming patterns: - - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). - - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). - - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). - - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). - - environment_state -> OBS_ENV_STATE, - - agent_pos -> OBS_STATE, - - observation.environment_state -> OBS_ENV_STATE, - - observation.agent_pos -> OBS_STATE + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """ + Transforms feature keys from the Gym standard to the LeRobot standard. + + This method standardizes the feature dictionary by renaming keys according + to LeRobot's conventions, ensuring that policies can be constructed correctly. + It handles various raw key formats, including those with an "observation." prefix. + + **Renaming Rules:** + - `pixels` or `observation.pixels` -> `observation.image` + - `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}` + - `environment_state` or `observation.environment_state` -> `observation.environment_state` + - `agent_pos` or `observation.agent_pos` -> `observation.state` + + Args: + features: The policy features dictionary with Gym-style keys. + + Returns: + The policy features dictionary with standardized LeRobot keys. + """ + # Build a new features mapping keyed by the same FeatureType buckets + # We assume callers already placed features in the correct FeatureType. + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()} + exact_pairs = { "pixels": OBS_IMAGE, "environment_state": OBS_ENV_STATE, @@ -129,29 +164,43 @@ class VanillaObservationProcessor(ObservationProcessor): "pixels.": f"{OBS_IMAGES}.", } - for key in list(features.keys()): - matched_prefix = False - for old_prefix, new_prefix in prefix_pairs.items(): - prefixed_old = f"observation.{old_prefix}" - if key.startswith(prefixed_old): - suffix = key[len(prefixed_old) :] - features[f"{new_prefix}{suffix}"] = features.pop(key) - matched_prefix = True - break + # Iterate over all incoming feature buckets and normalize/move each entry + for src_ft, bucket in features.items(): + for key, feat in list(bucket.items()): + handled = False - if key.startswith(old_prefix): - suffix = key[len(old_prefix) :] - features[f"{new_prefix}{suffix}"] = features.pop(key) - matched_prefix = True - break - - if matched_prefix: - continue - - for old, new in exact_pairs.items(): - if key == old or key == f"observation.{old}": - if key in features: - features[new] = features.pop(key) + # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) + for old_prefix, new_prefix in prefix_pairs.items(): + prefixed_old = f"observation.{old_prefix}" + if key.startswith(prefixed_old): + suffix = key[len(prefixed_old) :] + new_key = f"{new_prefix}{suffix}" + new_features[src_ft][new_key] = feat + handled = True break - return features + if key.startswith(old_prefix): + suffix = key[len(old_prefix) :] + new_key = f"{new_prefix}{suffix}" + new_features[src_ft][new_key] = feat + handled = True + break + + if handled: + continue + + # Exact-name rules (pixels, environment_state, agent_pos) + for old, new in exact_pairs.items(): + if key == old or key == f"observation.{old}": + new_key = new + new_features[src_ft][new_key] = feat + handled = True + break + + if handled: + continue + + # Default: keep key in the same source FeatureType bucket + new_features[src_ft][key] = feat + + return new_features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 6e1b2a2c..1c88cd74 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -13,72 +13,76 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This module defines a generic, sequential data processing pipeline framework, primarily designed for +transforming robotics data (observations, actions, rewards, etc.). + +The core components are: +- ProcessorStep: An abstract base class for a single data transformation operation. +- ProcessorStepRegistry: A mechanism to register and retrieve ProcessorStep classes by name. +- DataProcessorPipeline: A class that chains multiple ProcessorStep instances together to form a complete + data processing workflow. It integrates with the Hugging Face Hub for easy sharing and versioning of + pipelines, including their configuration and state. +- Specialized abstract ProcessorStep subclasses (e.g., ObservationProcessorStep, ActionProcessorStep) + to simplify the creation of steps that target specific parts of a data transition. +""" + from __future__ import annotations import importlib import json import os +import re +from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass, field -from enum import Enum from pathlib import Path -from typing import Any, Protocol, TypedDict +from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast import torch -from huggingface_hub import ModelHubMixin, hf_hub_download -from huggingface_hub.errors import HfHubHTTPError +from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.utils.hub import HubMixin +from .converters import batch_to_transition, create_transition, transition_to_batch +from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey -class TransitionKey(str, Enum): - """Keys for accessing EnvTransition dictionary components.""" - - # TODO(Steven): Use consts - OBSERVATION = "observation" - ACTION = "action" - REWARD = "reward" - DONE = "done" - TRUNCATED = "truncated" - INFO = "info" - COMPLEMENTARY_DATA = "complementary_data" - - -EnvTransition = TypedDict( - "EnvTransition", - { - TransitionKey.OBSERVATION.value: dict[str, Any] | None, - TransitionKey.ACTION.value: Any | torch.Tensor | None, - TransitionKey.REWARD.value: float | torch.Tensor | None, - TransitionKey.DONE.value: bool | torch.Tensor | None, - TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, - TransitionKey.INFO.value: dict[str, Any] | None, - TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, - }, -) +# Generic type variables for pipeline input and output. +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class ProcessorStepRegistry: - """Registry for processor steps that enables saving/loading by name instead of module path.""" + """A registry for ProcessorStep classes to allow instantiation from a string name. + + This class provides a way to map string identifiers to `ProcessorStep` classes, + which is useful for deserializing pipelines from configuration files without + + hardcoding class imports. + """ _registry: dict[str, type] = {} @classmethod - def register(cls, name: str = None): - """Decorator to register a processor step class. + def register(cls, name: str | None = None): + """A class decorator to register a ProcessorStep. Args: - name: Optional registration name. If not provided, uses class name. + name: The name to register the class under. If None, the class's `__name__` is used. - Example: - @ProcessorStepRegistry.register("adaptive_normalizer") - class AdaptiveObservationNormalizer: - ... + Returns: + A decorator function that registers the class and returns it. + + Raises: + ValueError: If a step with the same name is already registered. """ def decorator(step_class: type) -> type: + """The actual decorator that performs the registration.""" registration_name = name if name is not None else step_class.__name__ if registration_name in cls._registry: @@ -88,7 +92,7 @@ class ProcessorStepRegistry: ) cls._registry[registration_name] = step_class - # Store the registration name on the class for later reference + # Store the registration name on the class for easy lookup during serialization. step_class._registry_name = registration_name return step_class @@ -96,16 +100,16 @@ class ProcessorStepRegistry: @classmethod def get(cls, name: str) -> type: - """Get a registered processor step class by name. + """Retrieves a processor step class from the registry by its name. Args: - name: The registration name of the step. + name: The name of the step to retrieve. Returns: - The registered step class. + The processor step class corresponding to the given name. Raises: - KeyError: If the step is not registered. + KeyError: If the name is not found in the registry. """ if name not in cls._registry: available = list(cls._registry.keys()) @@ -118,310 +122,231 @@ class ProcessorStepRegistry: @classmethod def unregister(cls, name: str) -> None: - """Remove a step from the registry.""" + """Removes a processor step from the registry. + + Args: + name: The name of the step to unregister. + """ cls._registry.pop(name, None) @classmethod def list(cls) -> list[str]: - """List all registered step names.""" + """Returns a list of all registered processor step names.""" return list(cls._registry.keys()) @classmethod def clear(cls) -> None: - """Clear all registrations.""" + """Clears all processor steps from the registry.""" cls._registry.clear() -class ProcessorStep(Protocol): - """Structural typing interface for a single processor step. +class ProcessorStep(ABC): + """Abstract base class for a single step in a data processing pipeline. - A step is any callable accepting a full `EnvTransition` dict and - returning a (possibly modified) dict of the same structure. Implementers - are encouraged—but not required—to expose the optional helper methods - listed below. When present, these hooks let `RobotProcessor` - automatically serialise the step's configuration and learnable state using - a safe-to-share JSON + SafeTensors format. + Each step must implement the `__call__` method to perform its transformation + on a data transition and the `transform_features` method to describe how it + alters the shape or type of data features. - - **Required**: - - ``__call__(transition: EnvTransition) -> EnvTransition`` - - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` - - Optional helper protocol: - * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable - configuration and state. YOU decide what to save here. This is where all - non-tensor state goes (e.g., name, counter, threshold, window_size). - The config dict will be passed to your class constructor when loading. - * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. - This is exclusively for torch.Tensor objects (e.g., learned weights, - running statistics as tensors). Never put simple Python types here. - * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict - containing torch tensors only. - * ``reset()`` – Clear internal buffers at episode boundaries. - - Example separation: - - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} - - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + Subclasses can optionally be stateful by implementing `state_dict` and `load_state_dict`. """ - def __call__(self, transition: EnvTransition) -> EnvTransition: ... + _current_transition: EnvTransition | None = None - def get_config(self) -> dict[str, Any]: ... + @property + def transition(self) -> EnvTransition: + """Provides access to the most recent transition being processed. - def state_dict(self) -> dict[str, torch.Tensor]: ... + This is useful for steps that need to access other parts of the transition + data beyond their primary target (e.g., an action processing step that + needs to look at the observation). - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... + Raises: + ValueError: If accessed before the step has been called with a transition. + """ + if self._current_transition is None: + raise ValueError("Transition is not set. Make sure to call the step with a transition first.") + return self._current_transition - def reset(self) -> None: ... + @abstractmethod + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Processes an environment transition. - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + This method should contain the core logic of the processing step. + + Args: + transition: The input data transition to be processed. + + Returns: + The processed transition. + """ + return transition + + def get_config(self) -> dict[str, Any]: + """Returns the configuration of the step for serialization. + + Returns: + A JSON-serializable dictionary of configuration parameters. + """ + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Returns the state of the step (e.g., learned parameters, running means). + + Returns: + A dictionary mapping state names to tensors. + """ + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Loads the step's state from a state dictionary. + + Args: + state: A dictionary of state tensors. + """ + return None + + def reset(self) -> None: + """Resets the internal state of the processor step, if any.""" + return None + + @abstractmethod + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Defines how this step modifies the description of pipeline features. + + This method is used to track changes in data shapes, dtypes, or modalities + as data flows through the pipeline, without needing to process actual data. + + Args: + features: A dictionary describing the input features for observations, actions, etc. + + Returns: + A dictionary describing the output features after this step's transformation. + """ + return features -def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 - """Convert a *batch* dict coming from Learobot replay/dataset code into an - ``EnvTransition`` dictionary. +class ProcessorKwargs(TypedDict, total=False): + """A TypedDict for optional keyword arguments used in pipeline construction.""" - The function maps well known keys to the EnvTransition structure. Missing keys are - filled with sane defaults (``None`` or ``0.0``/``False``). - - Keys recognised (case-sensitive): - - * "observation.*" (keys starting with "observation." are grouped into observation dict) - * "action" - * "next.reward" - * "next.done" - * "next.truncated" - * "info" - - Additional keys are ignored so that existing dataloaders can carry extra - metadata without breaking the processor. - """ - - # Extract observation keys - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} - observation = observation_keys if observation_keys else None - - # Extract padding and task keys for complementary data - pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} - task_key = {"task": batch["task"]} if "task" in batch else {} - complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} - - transition: EnvTransition = { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: batch.get("action"), - TransitionKey.REWARD: batch.get("next.reward", 0.0), - TransitionKey.DONE: batch.get("next.done", False), - TransitionKey.TRUNCATED: batch.get("next.truncated", False), - TransitionKey.INFO: batch.get("info", {}), - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - return transition + to_transition: Callable[[dict[str, Any]], EnvTransition] | None + to_output: Callable[[EnvTransition], Any] | None + name: str | None + before_step_hooks: list[Callable[[int, EnvTransition], None]] | None + after_step_hooks: list[Callable[[int, EnvTransition], None]] | None -def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 - """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with - the canonical field names used throughout *LeRobot*. - """ +class ProcessorMigrationError(Exception): + """Raised when a model needs migration to the processor format""" - batch = { - "action": transition.get(TransitionKey.ACTION), - "next.reward": transition.get(TransitionKey.REWARD, 0.0), - "next.done": transition.get(TransitionKey.DONE, False), - "next.truncated": transition.get(TransitionKey.TRUNCATED, False), - "info": transition.get(TransitionKey.INFO, {}), - } - - # Add padding and task data from complementary_data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data: - pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} - batch.update(pad_data) - - if "task" in complementary_data: - batch["task"] = complementary_data["task"] - - # Handle observation - flatten dict to observation.* keys if it's a dict - observation = transition.get(TransitionKey.OBSERVATION) - if isinstance(observation, dict): - batch.update(observation) - - return batch + def __init__(self, model_path: str | Path, migration_command: str, original_error: str): + self.model_path = model_path + self.migration_command = migration_command + self.original_error = original_error + super().__init__( + f"Model '{model_path}' requires migration to processor format. " + f"Run: {migration_command}\n\nOriginal error: {original_error}" + ) @dataclass -class RobotProcessor(ModelHubMixin): - """ - Composable, debuggable post-processing processor for robot transitions. +class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): + """A sequential pipeline for processing data, integrated with the Hugging Face Hub. - The class orchestrates an ordered collection of small, functional transforms—steps—executed - left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts - and batch dictionaries, automatically converting between formats as needed. + This class chains together multiple `ProcessorStep` instances to form a complete + data processing workflow. It's generic, allowing for custom input and output types, + which are handled by the `to_transition` and `to_output` converters. - Args: - steps: Ordered list of processing steps executed on every call. Defaults to empty list. - name: Human-readable identifier that is persisted inside the JSON config. - Defaults to "RobotProcessor". - to_transition: Function to convert batch dict to EnvTransition dict. - Defaults to _default_batch_to_transition. - to_output: Function to convert EnvTransition dict to the desired output format. - Usually it is a batch dict or EnvTransition dict. - Defaults to _default_transition_to_batch. - before_step_hooks: List of hooks called before each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - after_step_hooks: List of hooks called after each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - - Hook Semantics: - - Hooks are executed sequentially in the order they were registered. There is no way to - reorder hooks after registration without creating a new pipeline. - - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called - with the step index and current transition for logging, debugging, or monitoring purposes. - - All hooks for a given type (before/after) are executed for every step, or none at all if - an error occurs. There is no partial execution of hooks. - - Hooks should generally be stateless to maintain predictable behavior. If you need stateful - processing, consider implementing a proper ProcessorStep instead. - - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. - - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format - passed to __call__. This ensures consistent hook behavior whether processing batch dicts - or EnvTransition objects. + Attributes: + steps: A sequence of `ProcessorStep` objects that make up the pipeline. + name: A descriptive name for the pipeline. + to_transition: A function to convert raw input data into the standardized `EnvTransition` format. + to_output: A function to convert the final `EnvTransition` into the desired output format. + before_step_hooks: A list of functions to be called before each step is executed. + after_step_hooks: A list of functions to be called after each step is executed. """ steps: Sequence[ProcessorStep] = field(default_factory=list) - name: str = "RobotProcessor" + name: str = "DataProcessorPipeline" - to_transition: Callable[[dict[str, Any]], EnvTransition] = field( - default_factory=lambda: _default_batch_to_transition, repr=False + to_transition: Callable[[TInput], EnvTransition] = field( + default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False ) - to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field( - default_factory=lambda: _default_transition_to_batch, repr=False + to_output: Callable[[EnvTransition], TOutput] = field( + default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch), + repr=False, ) - # Processor-level hooks for observation/monitoring - # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) - def __call__(self, data: EnvTransition | dict[str, Any]): - """Process data through all steps. - - The method accepts either the classic EnvTransition dict or a batch dictionary - (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied - it is first converted to the internal dict format using to_transition; after all - steps are executed the dict is transformed back into a batch dict with to_batch and the - result is returned – thereby preserving the caller's original data type. + def __call__(self, data: TInput) -> TOutput: + """Processes input data through the full pipeline. Args: - data: Either an EnvTransition dict or a batch dictionary to process. + data: The input data to process. Returns: - The processed data in the same format as the input (EnvTransition or batch dict). - - Raises: - ValueError: If the transition is not a valid EnvTransition format. + The processed data in the specified output format. """ - # Check if we need to convert back to batch format at the end - _, called_with_batch = self._prepare_transition(data) + transition = self.to_transition(data) + transformed_transition = self._forward(transition) + return self.to_output(transformed_transition) - # Use step_through to get the iterator - step_iterator = self.step_through(data) + def _forward(self, transition: EnvTransition) -> EnvTransition: + """Executes all processing steps and hooks in sequence. - # Get initial state (before any steps) - current_transition = next(step_iterator) + Args: + transition: The initial `EnvTransition` object. - # Process each step with hooks - for idx, next_transition in enumerate(step_iterator): - # Apply before hooks with current state (before step execution) + Returns: + The final `EnvTransition` after all steps have been applied. + """ + for idx, processor_step in enumerate(self.steps): + # Execute pre-hooks for hook in self.before_step_hooks: - hook(idx, current_transition) + hook(idx, transition) - # Move to next state (after step execution) - current_transition = next_transition + transition = processor_step(transition) - # Apply after hooks with updated state + # Execute post-hooks for hook in self.after_step_hooks: - hook(idx, current_transition) + hook(idx, transition) + return transition - # Convert back to original format if needed - return self.to_output(current_transition) if called_with_batch else current_transition + def step_through(self, data: TInput) -> Iterable[EnvTransition]: + """Processes data step-by-step, yielding the transition at each stage. - def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: - """Prepare and validate transition data for processing. + This is a generator method useful for debugging and inspecting the intermediate + state of the data as it passes through the pipeline. Args: - data: Either an EnvTransition dict or a batch dictionary to process. - - Returns: - A tuple of (prepared_transition, called_with_batch_flag) - - Raises: - ValueError: If the transition is not a valid EnvTransition format. - """ - # Check if data is already an EnvTransition or needs conversion - if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): - # It's a batch dict, convert it - called_with_batch = True - transition = self.to_transition(data) - else: - # It's already an EnvTransition - called_with_batch = False - transition = data - - # Basic validation - if not isinstance(transition, dict): - raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") - - return transition, called_with_batch - - def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: - """Yield the intermediate results after each processor step. - - This is a low-level method that does NOT apply hooks. It simply executes each step - and yields the intermediate results. This allows users to debug the pipeline or - apply custom logic between steps if needed. - - Note: This method always yields EnvTransition objects regardless of input format. - If you need the results in the original input format, you'll need to convert them - using `to_output()`. - - Args: - data: Either an EnvTransition dict or a batch dictionary to process. + data: The input data. Yields: - The intermediate EnvTransition results after each step. + The `EnvTransition` object, starting with the initial state and then after + each processing step. """ - transition, _ = self._prepare_transition(data) + transition = self.to_transition(data) - # Yield initial state + # Yield the initial state before any processing. yield transition - # Process each step WITHOUT hooks (low-level method) for processor_step in self.steps: transition = processor_step(transition) yield transition def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal save method for ModelHubMixin compatibility.""" - # Extract config_filename from kwargs if provided - config_filename = kwargs.pop("config_filename", None) - self.save_pretrained(save_directory, config_filename=config_filename) + """Internal method to comply with `HubMixin`'s saving mechanism. - def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): - """Serialize the processor definition and parameters to *save_directory*. - - Args: - save_directory: Directory where the processor will be saved. - config_filename: Optional custom config filename. If not provided, defaults to - "{self.name}.json" where self.name is sanitized for filesystem compatibility. + This method does the actual saving work and is called by HubMixin.save_pretrained. """ - os.makedirs(str(save_directory), exist_ok=True) + config_filename = kwargs.pop("config_filename", None) - # Sanitize processor name for use in filenames - import re - - # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + # Sanitize the pipeline name to create a valid filename prefix. sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) - # Use sanitized name for config if not provided if config_filename is None: config_filename = f"{sanitized_name}.json" @@ -430,40 +355,31 @@ class RobotProcessor(ModelHubMixin): "steps": [], } + # Iterate through each step to build its configuration entry. for step_index, processor_step in enumerate(self.steps): - # Check if step was registered registry_name = getattr(processor_step.__class__, "_registry_name", None) step_entry: dict[str, Any] = {} + # Prefer registry name for portability, otherwise fall back to full class path. if registry_name: - # Use registry name for registered steps step_entry["registry_name"] = registry_name else: - # Fall back to full module path for unregistered steps step_entry["class"] = ( f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" ) + # Save step configuration if `get_config` is implemented. if hasattr(processor_step, "get_config"): step_entry["config"] = processor_step.get_config() + # Save step state if `state_dict` is implemented and returns a non-empty dict. if hasattr(processor_step, "state_dict"): state = processor_step.state_dict() if state: - # Clone tensors to avoid shared memory issues - # This ensures each tensor has its own memory allocation - # The reason is to avoid the following error: - # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk - # and potential differences when loading them again - # ------------------------------------------------------------------------------ - # Since the state_dict of processor will be light, we can just clone the tensors - # and save them to the disk. - cloned_state = {} - for key, tensor in state.items(): - cloned_state[key] = tensor.clone() + # Clone tensors to avoid modifying the original state. + cloned_state = {key: tensor.clone() for key, tensor in state.items()} - # Include pipeline name and step index to ensure unique filenames - # This prevents conflicts when multiple processors are saved in the same directory + # Create a unique filename for the state file. if registry_name: state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" else: @@ -474,13 +390,69 @@ class RobotProcessor(ModelHubMixin): config["steps"].append(step_entry) + # Write the main configuration JSON file. with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: json.dump(config, file_pointer, indent=2) + def save_pretrained( + self, + save_directory: str | Path | None = None, + *, + repo_id: str | None = None, + push_to_hub: bool = False, + card_kwargs: dict[str, Any] | None = None, + config_filename: str | None = None, + **push_to_hub_kwargs, + ): + """Saves the pipeline's configuration and state to a directory. + + This method creates a JSON configuration file that defines the pipeline's structure + (name and steps). For each stateful step, it also saves a `.safetensors` file + containing its state dictionary. + + Args: + save_directory: The directory where the pipeline will be saved. If None, saves to + HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. + card_kwargs: Additional arguments passed to the card template to customize the card. + config_filename: The name of the JSON configuration file. If None, a name is + generated from the pipeline's `name` attribute. + **push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method. + """ + if save_directory is None: + # Use default directory in HF_LEROBOT_HOME + from lerobot.constants import HF_LEROBOT_HOME + + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name + + # For direct saves (not through hub), handle config_filename + if not push_to_hub and config_filename is not None: + # Call _save_pretrained directly with config_filename + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + self._save_pretrained(save_directory, config_filename=config_filename) + return None + + # Pass config_filename through kwargs for _save_pretrained when using hub + if config_filename is not None: + push_to_hub_kwargs["config_filename"] = config_filename + + # Call parent's save_pretrained which will call our _save_pretrained + return super().save_pretrained( + save_directory=save_directory, + repo_id=repo_id, + push_to_hub=push_to_hub, + card_kwargs=card_kwargs, + **push_to_hub_kwargs, + ) + @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | Path, + config_filename: str, *, force_download: bool = False, resume_download: bool | None = None, @@ -489,267 +461,798 @@ class RobotProcessor(ModelHubMixin): cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, - config_filename: str | None = None, overrides: dict[str, Any] | None = None, + to_transition: Callable[[TInput], EnvTransition] | None = None, + to_output: Callable[[EnvTransition], TOutput] | None = None, **kwargs, - ) -> RobotProcessor: - """Load a serialized processor from source (local path or Hugging Face Hub identifier). + ) -> DataProcessorPipeline[TInput, TOutput]: + """Loads a pipeline from a local directory, single file, or Hugging Face Hub repository. + + This method implements a simplified loading pipeline with intelligent migration detection: + + **Simplified Loading Strategy**: + 1. **Config Loading** (_load_config): + - **Directory**: Load specified config_filename from directory + - **Single file**: Load file directly (config_filename ignored) + - **Hub repository**: Download specified config_filename from Hub + + 2. **Config Validation** (_validate_loaded_config): + - Format validation: Ensure config is valid processor format + - Migration detection: Guide users to migrate old LeRobot models + - Clear errors: Provide actionable error messages + + 3. **Step Construction** (_build_steps_with_overrides): + - Class resolution: Registry lookup or dynamic imports + - Override merging: User parameters override saved config + - State loading: Load .safetensors files for stateful steps + + 4. **Override Validation** (_validate_overrides_used): + - Ensure all user overrides were applied (catch typos) + - Provide helpful error messages with available keys + + **Migration Detection**: + - **Smart detection**: Analyzes JSON files to detect old LeRobot models + - **Precise targeting**: Avoids false positives on other HuggingFace models + - **Clear guidance**: Provides exact migration command to run + - **Error mode**: Always raises ProcessorMigrationError for clear user action + + **Loading Examples**: + ```python + # Directory loading + pipeline = DataProcessorPipeline.from_pretrained("/models/my_model", config_filename="processor.json") + + # Single file loading + pipeline = DataProcessorPipeline.from_pretrained( + "/models/my_model/processor.json", config_filename="processor.json" + ) + + # Hub loading + pipeline = DataProcessorPipeline.from_pretrained("user/repo", config_filename="processor.json") + + # Multiple configs (preprocessor/postprocessor) + preprocessor = DataProcessorPipeline.from_pretrained( + "model", config_filename="policy_preprocessor.json" + ) + postprocessor = DataProcessorPipeline.from_pretrained( + "model", config_filename="policy_postprocessor.json" + ) + ``` + + **Override System**: + - **Key matching**: Use registry names or class names as override keys + - **Config merging**: User overrides take precedence over saved config + - **Validation**: Ensure all override keys match actual steps (catch typos) + - **Example**: overrides={"NormalizeStep": {"device": "cuda"}} Args: - pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier - (e.g., "username/processor-name"). - config_filename: Optional specific config filename to load. If not provided, will: - - For local paths: look for any .json file in the directory (error if multiple found) - - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") - overrides: Optional dictionary mapping step names to configuration overrides. - Keys must match exact step class names (for unregistered steps) or registry names - (for registered steps). Values are dictionaries containing parameter overrides - that will be merged with the saved configuration. This is useful for providing - non-serializable objects like environment instances. + pretrained_model_name_or_path: The identifier of the repository on the Hugging Face Hub, + a path to a local directory, or a path to a single config file. + config_filename: The name of the pipeline's JSON configuration file. Always required + to prevent ambiguity when multiple configs exist (e.g., preprocessor vs postprocessor). + force_download: Whether to force (re)downloading the files. + resume_download: Whether to resume a previously interrupted download. + proxies: A dictionary of proxy servers to use. + token: The token to use as HTTP bearer authorization for private Hub repositories. + cache_dir: The path to a specific cache folder to store downloaded files. + local_files_only: If True, avoid downloading files from the Hub. + revision: The specific model version to use (e.g., a branch name, tag name, or commit id). + overrides: A dictionary to override the configuration of specific steps. Keys should + match the step's class name or registry name. + to_transition: A custom function to convert input data to `EnvTransition`. + to_output: A custom function to convert the final `EnvTransition` to the output format. + **kwargs: Additional arguments (not used). Returns: - A RobotProcessor instance loaded from the saved configuration. + An instance of `DataProcessorPipeline` loaded with the specified configuration and state. Raises: - ImportError: If a processor step class cannot be loaded or imported. - ValueError: If a step cannot be instantiated with the provided configuration. - KeyError: If an override key doesn't match any step in the saved configuration. - - Examples: - Basic loading: - ```python - processor = RobotProcessor.from_pretrained("path/to/processor") - ``` - - Loading specific config file: - ```python - processor = RobotProcessor.from_pretrained( - "username/multi-processor-repo", config_filename="preprocessor.json" - ) - ``` - - Loading with overrides for non-serializable objects: - ```python - import gym - - env = gym.make("CartPole-v1") - processor = RobotProcessor.from_pretrained( - "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} - ) - ``` - - Multiple overrides: - ```python - processor = RobotProcessor.from_pretrained( - "path/to/processor", - overrides={ - "CustomStep": {"param1": "new_value"}, - "device_processor": {"device": "cuda:1"}, # For registered steps - }, - ) - ``` + FileNotFoundError: If the config file cannot be found. + ValueError: If configuration is ambiguous or instantiation fails. + ImportError: If a step's class cannot be imported. + KeyError: If an override key doesn't match any step in the pipeline. + ProcessorMigrationError: If the model requires migration to processor format. """ - # Use the local variable name 'source' for clarity - source = str(pretrained_model_name_or_path) + model_id = str(pretrained_model_name_or_path) + hub_download_kwargs = { + "force_download": force_download, + "resume_download": resume_download, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + "revision": revision, + } - if Path(source).is_dir(): - # Local path - use it directly - base_path = Path(source) + # 1. Load configuration using simplified 3-way logic + loaded_config, base_path = cls._load_config(model_id, config_filename, hub_download_kwargs) - if config_filename is None: - # Look for any .json file in the directory - json_files = list(base_path.glob("*.json")) - if len(json_files) == 0: - raise FileNotFoundError(f"No .json configuration files found in {source}") - elif len(json_files) > 1: - raise ValueError( - f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " - f"Please specify which one to load using the config_filename parameter." - ) - config_filename = json_files[0].name + # 2. Validate configuration and handle migration + cls._validate_loaded_config(model_id, loaded_config, config_filename) - with open(base_path / config_filename) as file_pointer: - loaded_config: dict[str, Any] = json.load(file_pointer) - else: - # Hugging Face Hub - download all required files - if config_filename is None: - # Try common config names - common_names = [ - "processor.json", - "preprocessor.json", - "postprocessor.json", - "robotprocessor.json", - ] - config_path = None - for name in common_names: - try: - config_path = hf_hub_download( - source, - name, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - config_filename = name - break - except (FileNotFoundError, OSError, HfHubHTTPError): - # FileNotFoundError: local file issues - # OSError: network/system errors - # HfHubHTTPError: file not found on Hub (404) or other HTTP errors - continue + # 3. Build steps with overrides + steps, validated_overrides = cls._build_steps_with_overrides( + loaded_config, overrides or {}, model_id, base_path, hub_download_kwargs + ) - if config_path is None: - raise FileNotFoundError( - f"No processor configuration file found in {source}. " - f"Tried: {common_names}. Please specify the config_filename parameter." - ) - else: - # Download specific config file - config_path = hf_hub_download( - source, - config_filename, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, + # 4. Validate that all overrides were used + cls._validate_overrides_used(validated_overrides, loaded_config) + + # 5. Construct and return the final pipeline instance + return cls( + steps=steps, + name=loaded_config.get("name", "DataProcessorPipeline"), + to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), + to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), + ) + + @classmethod + def _load_config( + cls, + model_id: str, + config_filename: str, + hub_download_kwargs: dict[str, Any], + ) -> tuple[dict[str, Any], Path]: + """Load configuration from local file or Hugging Face Hub. + + This method implements a super-simplified 3-way loading strategy: + + 1. **Local directory**: Load config_filename from directory + - Example: model_id="/models/my_model", config_filename="processor.json" + - Loads: "/models/my_model/processor.json" + + 2. **Single file**: Load file directly (ignore config_filename) + - Example: model_id="/models/my_model/processor.json" + - Loads: "/models/my_model/processor.json" (config_filename ignored) + + 3. **Hub repository**: Download config_filename from Hub + - Example: model_id="user/repo", config_filename="processor.json" + - Downloads and loads: config_filename from Hub repo + + **Benefits of Explicit config_filename**: + - No auto-detection complexity or edge cases + - No risk of loading wrong config (preprocessor vs postprocessor) + - Consistent behavior across local and Hub usage + - Clear, predictable errors + + Args: + model_id: The model identifier (Hub repo ID, local directory, or file path) + config_filename: The explicit config filename to load (always required) + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Returns: + Tuple of (loaded_config, base_path) + - loaded_config: Parsed JSON config dict (always loaded, never None) + - base_path: Directory containing config file (for state file resolution) + + Raises: + FileNotFoundError: If config file cannot be found locally or on Hub + """ + model_path = Path(model_id) + + if model_path.is_dir(): + # Directory: load specified config from directory + config_path = model_path / config_filename + if not config_path.exists(): + # Check for migration before giving clear error + if cls._should_suggest_migration(model_path): + cls._suggest_processor_migration(model_id, f"Config file '{config_filename}' not found") + raise FileNotFoundError( + f"Config file '{config_filename}' not found in directory '{model_id}'" ) - with open(config_path) as file_pointer: - loaded_config = json.load(file_pointer) + with open(config_path) as f: + return json.load(f), model_path - # Store downloaded files in the same directory as the config - base_path = Path(config_path).parent + elif model_path.is_file(): + # File: load file directly (config_filename is ignored for single files) + with open(model_path) as f: + return json.load(f), model_path.parent - # Handle None overrides - if overrides is None: - overrides = {} - - # Validate that all override keys will be matched - override_keys = set(overrides.keys()) - - steps: list[ProcessorStep] = [] - for step_entry in loaded_config["steps"]: - # Check if step uses registry name or module path - if "registry_name" in step_entry: - # Load from registry - try: - step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) - step_key = step_entry["registry_name"] - except KeyError as e: - raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e - else: - # Fall back to module path loading for backward compatibility - full_class_path = step_entry["class"] - module_path, class_name = full_class_path.rsplit(".", 1) - - # Import the module containing the step class - try: - module = importlib.import_module(module_path) - step_class = getattr(module, class_name) - step_key = class_name - except (ImportError, AttributeError) as e: - raise ImportError( - f"Failed to load processor step '{full_class_path}'. " - f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " - f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. " - f"Error: {str(e)}" - ) from e - - # Instantiate the step with its config + else: + # Hub: download specified config try: - saved_cfg = step_entry.get("config", {}) - step_overrides = overrides.get(step_key, {}) - merged_cfg = {**saved_cfg, **step_overrides} - step_instance: ProcessorStep = step_class(**merged_cfg) + config_path = hf_hub_download( + repo_id=model_id, + filename=config_filename, + repo_type="model", + **hub_download_kwargs, + ) - # Track which override keys were used - if step_key in override_keys: - override_keys.discard(step_key) + with open(config_path) as f: + return json.load(f), Path(config_path).parent except Exception as e: - step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) - raise ValueError( - f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " - f"Error: {str(e)}" + raise FileNotFoundError( + f"Could not find '{config_filename}' on the HuggingFace Hub at '{model_id}'" ) from e - # Load state if available - if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): - if Path(source).is_dir(): - # Local path - read directly - state_path = str(base_path / step_entry["state_file"]) - else: - # Hugging Face Hub - download the state file - state_path = hf_hub_download( - source, - step_entry["state_file"], - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) + @classmethod + def _validate_loaded_config( + cls, model_id: str, loaded_config: dict[str, Any], config_filename: str + ) -> None: + """Validate that a config was loaded and is a valid processor config. - step_instance.load_state_dict(load_file(state_path)) + This method validates processor config format with intelligent migration detection: + + **Config Format Validation**: + - Use _is_processor_config() to validate structure + - Must have "steps" field with list of step configurations + - Each step needs "class" or "registry_name" + - If validation fails AND local directory: Check for migration need + - If migration needed: Raise ProcessorMigrationError with command + - If no migration: Raise ValueError with helpful error message + + **Migration Detection Logic**: + - Only triggered for local directories (not Hub repos) + - Analyzes all JSON files in directory to detect old LeRobot models + - Provides exact migration command with model path + + Args: + model_id: The model identifier (used for migration detection) + loaded_config: The loaded config dictionary (guaranteed non-None) + config_filename: The config filename that was loaded (for error messages) + + Raises: + ValueError: If config format is invalid + ProcessorMigrationError: If model needs migration to processor format + """ + # Validate that this is actually a processor config + if not cls._is_processor_config(loaded_config): + if Path(model_id).is_dir() and cls._should_suggest_migration(Path(model_id)): + cls._suggest_processor_migration( + model_id, + f"Config file '{config_filename}' is not a valid processor configuration", + ) + raise ValueError( + f"Config file '{config_filename}' is not a valid processor configuration. " + f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}" + ) + + @classmethod + def _build_steps_with_overrides( + cls, + loaded_config: dict[str, Any], + overrides: dict[str, Any], + model_id: str, + base_path: Path | None, + hub_download_kwargs: dict[str, Any], + ) -> tuple[list[ProcessorStep], set[str]]: + """Build all processor steps with overrides and state loading. + + This method orchestrates the complete step construction pipeline: + + **For each step in loaded_config["steps"]**: + + 1. **Class Resolution** (via _resolve_step_class): + - **If "registry_name" exists**: Look up in ProcessorStepRegistry + Example: {"registry_name": "normalize_step"} -> Get registered class + - **Else use "class" field**: Dynamic import from full module path + Example: {"class": "lerobot.processor.normalize.NormalizeStep"} + - **Result**: (step_class, step_key) where step_key is used for overrides + + 2. **Step Instantiation** (via _instantiate_step): + - **Merge configs**: saved_config + user_overrides + - **Override priority**: User overrides take precedence over saved config + - **Example**: saved={"mean": 0.0}, override={"mean": 1.0} -> final={"mean": 1.0} + - **Result**: Instantiated ProcessorStep object + + 3. **State Loading** (via _load_step_state): + - **If step has "state_file"**: Load tensor state from .safetensors + - **Local first**: Check base_path/state_file.safetensors + - **Hub fallback**: Download state file if not found locally + - **Optional**: Only load if step has load_state_dict method + + 4. **Override Tracking**: + - **Track used overrides**: Remove step_key from remaining set + - **Purpose**: Validate all user overrides were applied (detect typos) + + **Error Handling**: + - Class resolution errors -> ImportError with helpful message + - Instantiation errors -> ValueError with config details + - State loading errors -> Propagated from load_state_dict + + Args: + loaded_config: The loaded processor configuration (must have "steps" field) + overrides: User-provided parameter overrides (keyed by class/registry name) + model_id: The model identifier (needed for Hub state file downloads) + base_path: Local directory path for finding state files + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Returns: + Tuple of (instantiated_steps_list, unused_override_keys) + - instantiated_steps_list: List of ready-to-use ProcessorStep instances + - unused_override_keys: Override keys that didn't match any step (for validation) + + Raises: + ImportError: If a step class cannot be imported or found in registry + ValueError: If a step cannot be instantiated with its configuration + """ + steps: list[ProcessorStep] = [] + override_keys = set(overrides.keys()) + + for step_entry in loaded_config["steps"]: + # 1. Get step class and key + step_class, step_key = cls._resolve_step_class(step_entry) + + # 2. Instantiate step with overrides + step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides) + + # 3. Load step state if available + cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs) + + # 4. Track used overrides + if step_key in override_keys: + override_keys.discard(step_key) steps.append(step_instance) - # Check for unused override keys - if override_keys: - available_keys = [] - for step_entry in loaded_config["steps"]: - if "registry_name" in step_entry: - available_keys.append(step_entry["registry_name"]) - else: - full_class_path = step_entry["class"] - class_name = full_class_path.rsplit(".", 1)[1] - available_keys.append(class_name) + return steps, override_keys - raise KeyError( - f"Override keys {list(override_keys)} do not match any step in the saved configuration. " - f"Available step keys: {available_keys}. " - f"Make sure override keys match exact step class names or registry names." + @classmethod + def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]: + """Resolve step class from registry or import path. + + This method implements a two-tier resolution strategy: + + **Tier 1: Registry-based resolution** (preferred): + - **If "registry_name" in step_entry**: Look up in ProcessorStepRegistry + - **Advantage**: Faster, no imports needed, guaranteed compatibility + - **Example**: {"registry_name": "normalize_step"} -> Get pre-registered class + - **Error**: KeyError if registry_name not found -> Convert to ImportError + + **Tier 2: Dynamic import fallback**: + - **Else use "class" field**: Full module.ClassName import path + - **Process**: Split "module.path.ClassName" into module + class parts + - **Import**: Use importlib.import_module() + getattr() + - **Example**: "lerobot.processor.normalize.NormalizeStep" + a. Import module: "lerobot.processor.normalize" + b. Get class: getattr(module, "NormalizeStep") + - **step_key**: Use class_name ("NormalizeStep") for overrides + + **Override Key Strategy**: + - Registry steps: Use registry_name ("normalize_step") + - Import steps: Use class_name ("NormalizeStep") + - This allows users to override with: {"normalize_step": {...}} or {"NormalizeStep": {...}} + + **Error Handling**: + - Registry KeyError -> ImportError with registry context + - Import/Attribute errors -> ImportError with helpful suggestions + - All errors include troubleshooting guidance + + Args: + step_entry: The step configuration dictionary (must have "registry_name" or "class") + + Returns: + Tuple of (step_class, step_key) + - step_class: The resolved ProcessorStep class (ready for instantiation) + - step_key: The key used for user overrides (registry_name or class_name) + + Raises: + ImportError: If step class cannot be loaded from registry or import path + """ + if "registry_name" in step_entry: + try: + step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + return step_class, step_entry["registry_name"] + except KeyError as e: + raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e + else: + # Fallback to dynamic import using the full class path + full_class_path = step_entry["class"] + module_path, class_name = full_class_path.rsplit(".", 1) + + try: + module = importlib.import_module(module_path) + step_class = getattr(module, class_name) + return step_class, class_name + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to load processor step '{full_class_path}'. " + f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " + f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. " + f"Error: {str(e)}" + ) from e + + @classmethod + def _instantiate_step( + cls, + step_entry: dict[str, Any], + step_class: type[ProcessorStep], + step_key: str, + overrides: dict[str, Any], + ) -> ProcessorStep: + """Instantiate a single processor step with config overrides. + + This method handles the configuration merging and instantiation logic: + + **Configuration Merging Strategy**: + 1. **Extract saved config**: Get step_entry.get("config", {}) from saved pipeline + - Example: {"config": {"mean": 0.0, "std": 1.0}} + 2. **Extract user overrides**: Get overrides.get(step_key, {}) for this step + - Example: overrides = {"NormalizeStep": {"mean": 2.0, "device": "cuda"}} + 3. **Merge with priority**: {**saved_cfg, **step_overrides} + - **Override priority**: User values override saved values + - **Result**: {"mean": 2.0, "std": 1.0, "device": "cuda"} + + **Instantiation Process**: + - **Call constructor**: step_class(**merged_cfg) + - **Example**: NormalizeStep(mean=2.0, std=1.0, device="cuda") + + **Error Handling**: + - **Any exception during instantiation**: Convert to ValueError + - **Include context**: step name, attempted config, original error + - **Purpose**: Help users debug configuration issues + - **Common causes**: + a. Invalid parameter types (str instead of float) + b. Missing required parameters + c. Incompatible parameter combinations + + Args: + step_entry: The step configuration from saved config (contains "config" dict) + step_class: The step class to instantiate (already resolved) + step_key: The key used for overrides ("registry_name" or class name) + overrides: User-provided parameter overrides (keyed by step_key) + + Returns: + The instantiated processor step (ready for use) + + Raises: + ValueError: If step cannot be instantiated, with detailed error context + """ + try: + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + return step_class(**merged_cfg) + except Exception as e: + step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) + raise ValueError( + f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " + f"Error: {str(e)}" + ) from e + + @classmethod + def _load_step_state( + cls, + step_instance: ProcessorStep, + step_entry: dict[str, Any], + model_id: str, + base_path: Path | None, + hub_download_kwargs: dict[str, Any], + ) -> None: + """Load state dictionary for a processor step if available. + + This method implements conditional state loading with local/Hub fallback: + + **Precondition Checks** (early return if not met): + 1. **"state_file" in step_entry**: Step config specifies a state file + - **If missing**: Step has no saved state (e.g., stateless transforms) + 2. **hasattr(step_instance, "load_state_dict")**: Step supports state loading + - **If missing**: Step doesn't implement state loading (rare) + + **State File Resolution Strategy**: + 1. **Local file priority**: Check base_path/state_filename exists + - **Advantage**: Faster, no network calls + - **Example**: "/models/my_model/normalize_step_0.safetensors" + - **Use case**: Loading from local saved model directory + + 2. **Hub download fallback**: Download state file from repository + - **When triggered**: Local file not found or base_path is None + - **Process**: Use hf_hub_download with same parameters as config + - **Example**: Download "normalize_step_0.safetensors" from "user/repo" + - **Result**: Downloaded to local cache, path returned + + **State Loading Process**: + - **Load tensors**: Use safetensors.torch.load_file() + - **Apply to step**: Call step_instance.load_state_dict(tensor_dict) + - **In-place modification**: Updates step's internal tensor state + + **Common state file examples**: + - "normalize_step_0.safetensors" - normalization statistics + - "custom_step_1.safetensors" - learned parameters + - "tokenizer_step_2.safetensors" - vocabulary embeddings + + Args: + step_instance: The step instance to load state into (must have load_state_dict) + step_entry: The step configuration dictionary (may contain "state_file") + model_id: The model identifier (used for Hub downloads if needed) + base_path: Local directory path for finding state files (None for Hub-only) + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Note: + This method modifies step_instance in-place and returns None. + If state loading fails, exceptions from load_state_dict propagate. + """ + if "state_file" not in step_entry or not hasattr(step_instance, "load_state_dict"): + return + + state_filename = step_entry["state_file"] + + # Try local file first + if base_path and (base_path / state_filename).exists(): + state_path = str(base_path / state_filename) + else: + # Download from Hub + state_path = hf_hub_download( + repo_id=model_id, + filename=state_filename, + repo_type="model", + **hub_download_kwargs, ) - return cls(steps, loaded_config.get("name", "RobotProcessor")) + step_instance.load_state_dict(load_file(state_path)) + + @classmethod + def _validate_overrides_used( + cls, remaining_override_keys: set[str], loaded_config: dict[str, Any] + ) -> None: + """Validate that all provided overrides were used. + + This method ensures user overrides are valid to catch typos and configuration errors: + + **Validation Logic**: + 1. **If remaining_override_keys is empty**: All overrides were used -> Success + - **Early return**: No validation needed + - **Normal case**: User provided correct override keys + + 2. **If remaining_override_keys has entries**: Some overrides unused -> Error + - **Root cause**: User provided keys that don't match any step + - **Common issues**: + a. Typos in step names ("NormalizStep" vs "NormalizeStep") + b. Using wrong key type (class name vs registry name) + c. Step doesn't exist in saved pipeline + + **Helpful Error Generation**: + - **Extract available keys**: Build list of valid override keys from config + a. **Registry steps**: Use "registry_name" directly + b. **Import steps**: Extract class name from "class" field + - Example: "lerobot.processor.normalize.NormalizeStep" -> "NormalizeStep" + - **Error message includes**: + a. Invalid keys provided by user + b. List of valid keys they can use + c. Guidance about registry vs class names + + **Override Key Resolution Rules**: + - Steps with "registry_name": Use registry_name for overrides + - Steps with "class": Use final class name for overrides + - Users must match these exact keys in their overrides dict + + Args: + remaining_override_keys: Override keys that weren't matched to any step + loaded_config: The loaded processor configuration (contains "steps" list) + + Raises: + KeyError: If any override keys were not used, with helpful error message + """ + if not remaining_override_keys: + return + + available_keys = [ + step.get("registry_name") or step["class"].rsplit(".", 1)[1] for step in loaded_config["steps"] + ] + + raise KeyError( + f"Override keys {list(remaining_override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." + ) + + @classmethod + def _should_suggest_migration(cls, model_path: Path) -> bool: + """Check if directory has JSON files but no processor configs. + + This method implements smart migration detection to avoid false positives: + + **Decision Logic**: + 1. **No JSON files found**: Return False + - **Reason**: Empty directory or only non-config files + - **Example**: Directory with only .safetensors, .md files + - **Action**: No migration needed + + 2. **JSON files exist**: Analyze each file + - **Goal**: Determine if ANY file is a valid processor config + - **Process**: + a. Try to parse each .json file + b. Skip files with JSON parse errors (malformed) + c. Check if parsed config passes _is_processor_config() + - **If ANY valid processor found**: Return False (no migration) + - **If NO valid processors found**: Return True (migration needed) + + **Examples**: + - **No migration**: ["processor.json", "config.json"] where processor.json is valid + - **Migration needed**: ["config.json", "train.json"] where both are model configs + - **No migration**: [] (empty directory) + - **Migration needed**: ["old_model_config.json"] with old LeRobot format + + **Why this works**: + - **Precise detection**: Only suggests migration for actual old LeRobot models + - **Avoids false positives**: Won't trigger on other HuggingFace model types + - **Graceful handling**: Ignores malformed JSON files + + Args: + model_path: Path to local directory to analyze + + Returns: + True if directory has JSON configs but none are processor configs (migration needed) + False if no JSON files or at least one valid processor config exists + """ + json_files = list(model_path.glob("*.json")) + if len(json_files) == 0: + return False + + # Check if any JSON file is a processor config + for json_file in json_files: + try: + with open(json_file) as f: + config = json.load(f) + + if cls._is_processor_config(config): + return False # Found at least one processor config, no migration needed + + except (json.JSONDecodeError, OSError): + # Skip files that can't be parsed as JSON + continue + + # Have JSON files but no processor configs - suggest migration + return True + + @classmethod + def _is_processor_config(cls, config: dict) -> bool: + """Check if config follows DataProcessorPipeline format. + + This method validates the processor configuration structure: + + **Required Structure Validation**: + 1. **"steps" field existence**: Must have top-level "steps" key + - **If missing**: Not a processor config (e.g., model config, train config) + - **Example invalid**: {"type": "act", "hidden_dim": 256} + + 2. **"steps" field type**: Must be a list, not other types + - **If not list**: Invalid format + - **Example invalid**: {"steps": "some_string"} or {"steps": {"key": "value"}} + + 3. **Empty steps validation**: Empty list is valid + - **If len(steps) == 0**: Return True immediately + - **Use case**: Empty processor pipeline (no-op) + - **Example valid**: {"name": "EmptyProcessor", "steps": []} + + **Individual Step Validation** (for non-empty steps): + For each step in the steps list: + 1. **Step type**: Must be a dictionary + - **If not dict**: Invalid step format + - **Example invalid**: ["string_step", 123, true] + + 2. **Step identifier**: Must have either "class" OR "registry_name" + - **"registry_name"**: Registered step (preferred) + Example: {"registry_name": "normalize_step", "config": {...}} + - **"class"**: Full import path + Example: {"class": "lerobot.processor.normalize.NormalizeStep"} + - **If neither**: Invalid step (can't resolve class) + - **If both**: Also valid (registry_name takes precedence) + + **Valid Processor Config Examples**: + - {"steps": []} - Empty processor + - {"steps": [{"registry_name": "normalize"}]} - Registry step + - {"steps": [{"class": "my.module.Step"}]} - Import step + - {"name": "MyProcessor", "steps": [...]} - With name + + **Invalid Config Examples**: + - {"type": "act"} - Missing "steps" + - {"steps": "normalize"} - Steps not a list + - {"steps": [{}]} - Step missing class/registry_name + - {"steps": ["string"]} - Step not a dict + + Args: + config: The configuration dictionary to validate + + Returns: + True if config follows valid DataProcessorPipeline format, False otherwise + """ + # Must have a "steps" field with a list of step configurations + if not isinstance(config.get("steps"), list): + return False + + steps = config["steps"] + if len(steps) == 0: + return True # Empty processor is valid + + # Each step must be a dict with either "class" or "registry_name" + for step in steps: + if not isinstance(step, dict): + return False + if not ("class" in step or "registry_name" in step): + return False + + return True + + @classmethod + def _suggest_processor_migration(cls, model_path: str | Path, original_error: str) -> None: + """Raise migration error when we detect JSON files but no processor configs. + + This method is called when migration detection determines that a model + directory contains configuration files but none are valid processor configs. + This typically indicates an old LeRobot model that needs migration. + + **When this is called**: + - User tries to load DataProcessorPipeline from local directory + - Directory contains JSON configuration files + - None of the JSON files follow processor config format + - _should_suggest_migration() returned True + + **Migration Command Generation**: + - Constructs exact command user needs to run + - Uses the migration script: migrate_policy_normalization.py + - Includes the model path automatically + - Example: "python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path /models/old_model" + + **Error Structure**: + - **Always raises**: ProcessorMigrationError (never returns) + - **Includes**: model_path, migration_command, original_error + - **Purpose**: Force user attention to migration need + - **User experience**: Clear actionable error with exact command to run + + **Migration Process**: + The suggested command will: + 1. Extract normalization stats from old model + 2. Create new processor configs (preprocessor + postprocessor) + 3. Remove normalization layers from model + 4. Save migrated model with processor pipeline + + Args: + model_path: Path to the model directory needing migration + original_error: The error that triggered migration detection (for context) + + Raises: + ProcessorMigrationError: Always raised (this method never returns normally) + """ + migration_command = ( + f"python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path {model_path}" + ) + + raise ProcessorMigrationError(model_path, migration_command, original_error) def __len__(self) -> int: - """Return the number of steps in the processor.""" + """Returns the number of steps in the pipeline.""" return len(self.steps) - def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor: - """Indexing helper exposing underlying steps. - * ``int`` – returns the idx-th ProcessorStep. - * ``slice`` – returns a new RobotProcessor with the sliced steps. + def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]: + """Retrieves a step or a sub-pipeline by index or slice. + + Args: + idx: An integer index or a slice object. + + Returns: + A `ProcessorStep` if `idx` is an integer, or a new `DataProcessorPipeline` + containing the sliced steps. """ if isinstance(idx, slice): - return RobotProcessor(self.steps[idx], self.name) + # Return a new pipeline instance with the sliced steps. + return DataProcessorPipeline( + steps=self.steps[idx], + name=self.name, + to_transition=self.to_transition, + to_output=self.to_output, + before_step_hooks=self.before_step_hooks.copy(), + after_step_hooks=self.after_step_hooks.copy(), + ) return self.steps[idx] def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed before every processor step.""" + """Registers a function to be called before each step. + + Args: + fn: A callable that accepts the step index and the current transition. + """ self.before_step_hooks.append(fn) def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered before_step hook. + """Unregisters a 'before_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.before_step_hooks.remove(fn) @@ -759,17 +1262,21 @@ class RobotProcessor(ModelHubMixin): ) from None def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed after every processor step.""" + """Registers a function to be called after each step. + + Args: + fn: A callable that accepts the step index and the current transition. + """ self.after_step_hooks.append(fn) def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered after_step hook. + """Unregisters an 'after_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.after_step_hooks.remove(fn) @@ -779,13 +1286,13 @@ class RobotProcessor(ModelHubMixin): ) from None def reset(self): - """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + """Resets the state of all stateful steps in the pipeline.""" for step in self.steps: if hasattr(step, "reset"): - step.reset() # type: ignore[attr-defined] + step.reset() def __repr__(self) -> str: - """Return a readable string representation of the processor.""" + """Provides a concise string representation of the pipeline.""" step_names = [step.__class__.__name__ for step in self.steps] if not step_names: @@ -793,472 +1300,417 @@ class RobotProcessor(ModelHubMixin): elif len(step_names) <= 3: steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]" else: - # Show first 2 and last 1 with ellipsis for long lists + # For long pipelines, show the first, second, and last steps. displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}" steps_repr = f"steps={len(step_names)}: [{displayed}]" parts = [f"name='{self.name}'", steps_repr] - return f"RobotProcessor({', '.join(parts)})" + return f"DataProcessorPipeline({', '.join(parts)})" def __post_init__(self): + """Validates that all provided steps are instances of `ProcessorStep`.""" for i, step in enumerate(self.steps): - if not callable(step): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" - ) + if not isinstance(step, ProcessorStep): + raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep") - fc = getattr(step, "feature_contract", None) - if not callable(fc): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" - ) + def transform_features( + self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Applies feature transformations from all steps sequentially. - def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + This method propagates a feature description dictionary through each step's + `transform_features` method, allowing the pipeline to statically determine + the output feature specification without processing any real data. + + Args: + initial_features: A dictionary describing the initial features. + + Returns: + The final feature description after all transformations. """ - Apply ALL steps in order. Each step must implement - feature_contract(features) and return a dict (full or incremental schema). - """ - features: dict[str, PolicyFeature] = deepcopy(initial_features) + features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features) for _, step in enumerate(self.steps): - out = step.feature_contract(features) - if not isinstance(out, dict): - raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + out = step.transform_features(features) features = out return features - -class ObservationProcessor: - """Base class for processors that modify only the observation component of a transition. - - Subclasses should override the `observation` method to implement custom observation processing. - This class handles the boilerplate of extracting and reinserting the processed observation - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class MyObservationScaler(ObservationProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def observation(self, observation): - return observation * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific observation processing logic. - """ - - def observation(self, observation): - """Process the observation component. + # Convenience methods for processing individual parts of a transition. + def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """Processes only the observation part of a transition through the pipeline. Args: - observation: The observation to process + observation: The observation dictionary. Returns: - The processed observation + The processed observation dictionary. """ - return observation + transition: EnvTransition = create_transition(observation=observation) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.OBSERVATION] + + def process_action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: + """Processes only the action part of a transition through the pipeline. + + Args: + action: The action data. + + Returns: + The processed action. + """ + transition: EnvTransition = create_transition(action=action) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.ACTION] + + def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor: + """Processes only the reward part of a transition through the pipeline. + + Args: + reward: The reward value. + + Returns: + The processed reward. + """ + transition: EnvTransition = create_transition(reward=reward) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.REWARD] + + def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the done flag of a transition through the pipeline. + + Args: + done: The done flag. + + Returns: + The processed done flag. + """ + transition: EnvTransition = create_transition(done=done) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.DONE] + + def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the truncated flag of a transition through the pipeline. + + Args: + truncated: The truncated flag. + + Returns: + The processed truncated flag. + """ + transition: EnvTransition = create_transition(truncated=truncated) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.TRUNCATED] + + def process_info(self, info: dict[str, Any]) -> dict[str, Any]: + """Processes only the info dictionary of a transition through the pipeline. + + Args: + info: The info dictionary. + + Returns: + The processed info dictionary. + """ + transition: EnvTransition = create_transition(info=info) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.INFO] + + def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]: + """Processes only the complementary data part of a transition through the pipeline. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The processed complementary data dictionary. + """ + transition: EnvTransition = create_transition(complementary_data=complementary_data) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.COMPLEMENTARY_DATA] + + +# Type aliases for semantic clarity. +RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] +PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] + + +class ObservationProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the observation in a transition.""" + + @abstractmethod + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """Processes an observation dictionary. Subclasses must implement this method. + + Args: + observation: The input observation dictionary from the transition. + + Returns: + The processed observation dictionary. + """ + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition + """Applies the `observation` method to the transition's observation.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_observation = self.observation(observation) - # Create a new transition dict with the processed observation - new_transition = transition.copy() + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is None or not isinstance(observation, dict): + raise ValueError("ObservationProcessorStep requires an observation in the transition.") + + processed_observation = self.observation(observation.copy()) new_transition[TransitionKey.OBSERVATION] = processed_observation return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class ActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the action in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ActionProcessor: - """Base class for processors that modify only the action component of a transition. - - Subclasses should override the `action` method to implement custom action processing. - This class handles the boilerplate of extracting and reinserting the processed action - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class ActionClipping(ActionProcessor): - def __init__(self, min_val, max_val): - self.min_val = min_val - self.max_val = max_val - - def action(self, action): - return np.clip(action, self.min_val, self.max_val) - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific action processing logic. - """ - - def action(self, action): - """Process the action component. + @abstractmethod + def action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: + """Processes an action. Subclasses must implement this method. Args: - action: The action to process + action: The input action from the transition. Returns: - The processed action + The processed action. """ - return action + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition.get(TransitionKey.ACTION) + """Applies the `action` method to the transition's action.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) if action is None: - return transition + raise ValueError("ActionProcessorStep requires an action in the transition.") processed_action = self.action(action) - # Create a new transition dict with the processed action - new_transition = transition.copy() new_transition[TransitionKey.ACTION] = processed_action return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class RobotActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` for processing a `RobotAction` (a dictionary).""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class RewardProcessor: - """Base class for processors that modify only the reward component of a transition. - - Subclasses should override the `reward` method to implement custom reward processing. - This class handles the boilerplate of extracting and reinserting the processed reward - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class RewardScaler(RewardProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def reward(self, reward): - return reward * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific reward processing logic. - """ - - def reward(self, reward): - """Process the reward component. + @abstractmethod + def action(self, action: RobotAction) -> RobotAction: + """Processes a `RobotAction`. Subclasses must implement this method. Args: - reward: The reward to process + action: The input `RobotAction` dictionary. Returns: - The processed reward + The processed `RobotAction`. """ - return reward + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - reward = transition.get(TransitionKey.REWARD) + """Applies the `action` method to the transition's action, ensuring it's a `RobotAction`.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is None or not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type (dict), but got {type(action)}") + + processed_action = self.action(action.copy()) + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + +class PolicyActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` for processing a `PolicyAction` (a tensor or dict of tensors).""" + + @abstractmethod + def action(self, action: PolicyAction) -> PolicyAction: + """Processes a `PolicyAction`. Subclasses must implement this method. + + Args: + action: The input `PolicyAction`. + + Returns: + The processed `PolicyAction`. + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `action` method to the transition's action, ensuring it's a `PolicyAction`.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type (tensor), but got {type(action)}") + + processed_action = self.action(action) + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + +class RewardProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the reward in a transition.""" + + @abstractmethod + def reward(self, reward) -> float | torch.Tensor: + """Processes a reward. Subclasses must implement this method. + + Args: + reward: The input reward from the transition. + + Returns: + The processed reward. + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `reward` method to the transition's reward.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + reward = new_transition.get(TransitionKey.REWARD) if reward is None: - return transition + raise ValueError("RewardProcessorStep requires a reward in the transition.") processed_reward = self.reward(reward) - # Create a new transition dict with the processed reward - new_transition = transition.copy() new_transition[TransitionKey.REWARD] = processed_reward return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class DoneProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'done' flag in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class DoneProcessor: - """Base class for processors that modify only the done flag of a transition. - - Subclasses should override the `done` method to implement custom done flag processing. - This class handles the boilerplate of extracting and reinserting the processed done flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class TimeoutDone(DoneProcessor): - def __init__(self, max_steps): - self.steps = 0 - self.max_steps = max_steps - - def done(self, done): - self.steps += 1 - return done or self.steps >= self.max_steps - - def reset(self): - self.steps = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific done flag processing logic. - """ - - def done(self, done): - """Process the done flag. + @abstractmethod + def done(self, done) -> bool | torch.Tensor: + """Processes a 'done' flag. Subclasses must implement this method. Args: - done: The done flag to process + done: The input 'done' flag from the transition. Returns: - The processed done flag + The processed 'done' flag. """ - return done + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - done = transition.get(TransitionKey.DONE) + """Applies the `done` method to the transition's 'done' flag.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + done = new_transition.get(TransitionKey.DONE) if done is None: - return transition + raise ValueError("DoneProcessorStep requires a done flag in the transition.") processed_done = self.done(done) - # Create a new transition dict with the processed done flag - new_transition = transition.copy() new_transition[TransitionKey.DONE] = processed_done return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class TruncatedProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'truncated' flag in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class TruncatedProcessor: - """Base class for processors that modify only the truncated flag of a transition. - - Subclasses should override the `truncated` method to implement custom truncated flag processing. - This class handles the boilerplate of extracting and reinserting the processed truncated flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class EarlyTruncation(TruncatedProcessor): - def __init__(self, threshold): - self.threshold = threshold - - def truncated(self, truncated): - # Additional truncation condition - return truncated or some_condition > self.threshold - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific truncated flag processing logic. - """ - - def truncated(self, truncated): - """Process the truncated flag. + @abstractmethod + def truncated(self, truncated) -> bool | torch.Tensor: + """Processes a 'truncated' flag. Subclasses must implement this method. Args: - truncated: The truncated flag to process + truncated: The input 'truncated' flag from the transition. Returns: - The processed truncated flag + The processed 'truncated' flag. """ - return truncated + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition.get(TransitionKey.TRUNCATED) + """Applies the `truncated` method to the transition's 'truncated' flag.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + truncated = new_transition.get(TransitionKey.TRUNCATED) if truncated is None: - return transition + raise ValueError("TruncatedProcessorStep requires a truncated flag in the transition.") processed_truncated = self.truncated(truncated) - # Create a new transition dict with the processed truncated flag - new_transition = transition.copy() new_transition[TransitionKey.TRUNCATED] = processed_truncated return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class InfoProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'info' dictionary in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class InfoProcessor: - """Base class for processors that modify only the info dictionary of a transition. - - Subclasses should override the `info` method to implement custom info processing. - This class handles the boilerplate of extracting and reinserting the processed info - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class InfoAugmenter(InfoProcessor): - def __init__(self): - self.step_count = 0 - - def info(self, info): - info = info.copy() # Create a copy to avoid modifying the original - info["steps"] = self.step_count - self.step_count += 1 - return info - - def reset(self): - self.step_count = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific info dictionary processing logic. - """ - - def info(self, info): - """Process the info dictionary. + @abstractmethod + def info(self, info) -> dict[str, Any]: + """Processes an 'info' dictionary. Subclasses must implement this method. Args: - info: The info dictionary to process + info: The input 'info' dictionary from the transition. Returns: - The processed info dictionary + The processed 'info' dictionary. """ - return info + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - info = transition.get(TransitionKey.INFO) - if info is None: - return transition + """Applies the `info` method to the transition's 'info' dictionary.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_info = self.info(info) - # Create a new transition dict with the processed info - new_transition = transition.copy() + info = new_transition.get(TransitionKey.INFO) + if info is None or not isinstance(info, dict): + raise ValueError("InfoProcessorStep requires an info dictionary in the transition.") + + processed_info = self.info(info.copy()) new_transition[TransitionKey.INFO] = processed_info return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class ComplementaryDataProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that targets the 'complementary_data' in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ComplementaryDataProcessor: - """Base class for processors that modify only the complementary data of a transition. - - Subclasses should override the `complementary_data` method to implement custom complementary data processing. - This class handles the boilerplate of extracting and reinserting the processed complementary data - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - """ - - def complementary_data(self, complementary_data): - """Process the complementary data. + @abstractmethod + def complementary_data(self, complementary_data) -> dict[str, Any]: + """Processes a 'complementary_data' dictionary. Subclasses must implement this method. Args: - complementary_data: The complementary data to process + complementary_data: The input 'complementary_data' from the transition. Returns: - The processed complementary data + The processed 'complementary_data' dictionary. """ - return complementary_data + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is None: - return transition + """Applies the `complementary_data` method to the transition's data.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_complementary_data = self.complementary_data(complementary_data) - # Create a new transition dict with the processed complementary data - new_transition = transition.copy() + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None or not isinstance(complementary_data, dict): + raise ValueError("ComplementaryDataProcessorStep requires complementary data in the transition.") + + processed_complementary_data = self.complementary_data(complementary_data.copy()) new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class IdentityProcessorStep(ProcessorStep): + """A no-op processor step that returns the input transition and features unchanged. - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class IdentityProcessor: - """Identity processor that does nothing.""" + This can be useful as a placeholder or for debugging purposes. + """ def __call__(self, transition: EnvTransition) -> EnvTransition: + """Returns the transition without modification.""" return transition - def get_config(self) -> dict[str, Any]: - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Returns the features without modification.""" return features diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py new file mode 100644 index 00000000..74c53499 --- /dev/null +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -0,0 +1,52 @@ +from dataclasses import asdict, dataclass +from typing import Any + +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction + + +@dataclass +@ProcessorStepRegistry.register("robot_action_to_policy_action_processor") +class RobotActionToPolicyActionProcessorStep(ActionProcessorStep): + """Processor step to map a dictionary to a tensor action.""" + + motor_names: list[str] + + def action(self, action: RobotAction) -> PolicyAction: + if len(self.motor_names) != len(action): + raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}") + return torch.tensor([action[f"{name}.pos"] for name in self.motor_names]) + + def get_config(self) -> dict[str, Any]: + return asdict(self) + + def transform_features(self, features): + features[PipelineFeatureType.ACTION]["action"] = PolicyFeature( + type=FeatureType.ACTION, shape=(len(self.motor_names),) + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("policy_action_to_robot_action_processor") +class PolicyActionToRobotActionProcessorStep(ActionProcessorStep): + """Processor step to map a policy action to a robot action.""" + + motor_names: list[str] + + def action(self, action: PolicyAction) -> RobotAction: + if len(self.motor_names) != len(action): + raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}") + return {f"{name}.pos": action[i] for i, name in enumerate(self.motor_names)} + + def get_config(self) -> dict[str, Any]: + return asdict(self) + + def transform_features(self, features): + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 4fe4105a..6cae5921 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -13,20 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from dataclasses import dataclass, field from typing import Any -from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import ( - ObservationProcessor, - ProcessorStepRegistry, -) +from lerobot.configs.types import PipelineFeatureType, PolicyFeature + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass -@ProcessorStepRegistry.register(name="rename_processor") -class RenameProcessor(ObservationProcessor): - """Rename processor that renames keys in the observation.""" +@ProcessorStepRegistry.register(name="rename_observations_processor") +class RenameObservationsProcessorStep(ObservationProcessorStep): + """ + A processor step that renames keys in an observation dictionary. + + This step is useful for creating a standardized data interface by mapping keys + from an environment's format to the format expected by a LeRobot policy or + other downstream components. + + Attributes: + rename_map: A dictionary mapping from old key names to new key names. + Keys present in an observation that are not in this map will + be kept with their original names. + """ rename_map: dict[str, str] = field(default_factory=dict) @@ -43,9 +53,41 @@ class RenameProcessor(ObservationProcessor): def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """Transforms: - Each key in the observation that appears in `rename_map` is renamed to its value. - Keys not in `rename_map` remain unchanged. """ - return {self.rename_map.get(k, k): v for k, v in features.items()} + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = features.copy() + new_features[PipelineFeatureType.OBSERVATION] = { + self.rename_map.get(k, k): v for k, v in features[PipelineFeatureType.OBSERVATION].items() + } + return new_features + + +def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: + """ + Renames the top-level keys in a statistics dictionary using a provided mapping. + + This is a helper function typically used to keep normalization statistics + consistent with renamed observation or action features. It performs a defensive + deep copy to avoid modifying the original `stats` dictionary. + + Args: + stats: A nested dictionary of statistics, where top-level keys are + feature names (e.g., `{"observation.state": {"mean": 0.5}}`). + rename_map: A dictionary mapping old feature names to new feature names. + + Returns: + A new statistics dictionary with its top-level keys renamed. Returns an + empty dictionary if the input `stats` is empty. + """ + if not stats: + return {} + renamed: dict[str, dict[str, Any]] = {} + for old_key, sub_stats in stats.items(): + new_key = rename_map.get(old_key, old_key) + renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {} + return renamed diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py new file mode 100644 index 00000000..23db7b5e --- /dev/null +++ b/src/lerobot/processor/tokenizer_processor.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script defines a processor for tokenizing natural language instructions from an environment transition. + +It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into +token IDs and attention masks, which are then added to the observation dictionary. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.import_utils import _transformers_available + +from .core import EnvTransition, TransitionKey +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer +else: + AutoTokenizer = None + + +@dataclass +@ProcessorStepRegistry.register(name="tokenizer_processor") +class TokenizerProcessorStep(ObservationProcessorStep): + """ + Processor step to tokenize a natural language task description. + + This step extracts a task string from the `complementary_data` of an `EnvTransition`, + tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting + token IDs and attention mask to the `observation` dictionary. + + Requires the `transformers` library to be installed. + + Attributes: + tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased"). + tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored. + max_length: The maximum length to pad or truncate sequences to. + task_key: The key in `complementary_data` where the task string is stored. + padding_side: The side to pad on ('left' or 'right'). + padding: The padding strategy ('max_length', 'longest', etc.). + truncation: Whether to truncate sequences longer than `max_length`. + input_tokenizer: The internal tokenizer instance, loaded during initialization. + """ + + tokenizer_name: str | None = None + tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency + max_length: int = 512 + task_key: str = "task" + padding_side: str = "right" + padding: str = "max_length" + truncation: bool = True + + # Internal tokenizer instance (not part of the config) + input_tokenizer: Any = field(default=None, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the tokenizer after the dataclass is created. + + It checks for the availability of the `transformers` library and loads the tokenizer + either from a provided object or by name from the Hugging Face Hub. + + Raises: + ImportError: If the `transformers` library is not installed. + ValueError: If neither `tokenizer` nor `tokenizer_name` is provided. + """ + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep." + ) + + if self.tokenizer is not None: + # Use provided tokenizer object directly + self.input_tokenizer = self.tokenizer + elif self.tokenizer_name is not None: + if AutoTokenizer is None: + raise ImportError("AutoTokenizer is not available") + self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + else: + raise ValueError( + "Either 'tokenizer' or 'tokenizer_name' must be provided. " + "Pass a tokenizer object directly or a tokenizer name to auto-load." + ) + + def get_task(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the task description(s) from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of task strings, or None if the task key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + raise ValueError("Complementary data is None so no task can be extracted from it") + + task = complementary_data[self.task_key] + if task is None: + raise ValueError("Task extracted from Complementary data is None") + + # Standardize to a list of strings for the tokenizer + if isinstance(task, str): + return [task] + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + return task + + return None + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """ + Tokenizes the task description and adds it to the observation dictionary. + + This method retrieves the task, tokenizes it, moves the resulting tensors to the + same device as other data in the transition, and updates the observation. + + Args: + observation: The original observation dictionary. + + Returns: + The updated observation dictionary including token IDs and an attention mask. + """ + task = self.get_task(self.transition) + if task is None: + raise ValueError("Task cannot be None") + + # Tokenize the task (this will create CPU tensors) + tokenized_prompt = self._tokenize_text(task) + + # Detect the device from existing tensors in the transition to ensure consistency + target_device = self._detect_device(self.transition) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_prompt = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_prompt.items() + } + + # Create a new observation dict to avoid modifying the original in place + new_observation = dict(observation) + + # Add tokenized data to the observation + new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] + new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + + return new_observation + + def _detect_device(self, transition: EnvTransition) -> torch.device | None: + """ + Detects the torch.device from existing tensors in the transition. + + It checks tensors in the observation dictionary first, then the action tensor. + + Args: + transition: The environment transition. + + Returns: + The detected `torch.device`, or None if no tensors are found. + """ + # Check observation tensors first (most likely place to find tensors) + observation = transition.get(TransitionKey.OBSERVATION) + if observation: + for value in observation.values(): + if isinstance(value, torch.Tensor): + return value.device + + # Fallback to checking the action tensor + action = transition.get(TransitionKey.ACTION) + if isinstance(action, torch.Tensor): + return action.device + + return None # No tensors found, default will be CPU + + def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: + """ + A wrapper around the tokenizer call. + + Args: + text: A string or list of strings to tokenize. + + Returns: + A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors. + """ + return self.input_tokenizer( + text, + max_length=self.max_length, + truncation=self.truncation, + padding=self.padding, + padding_side=self.padding_side, + return_tensors="pt", + ) + + def get_config(self) -> dict[str, Any]: + """ + Returns the serializable configuration of the processor. + + Note: The tokenizer object itself is not serialized. If the processor was initialized + with a tokenizer name, that name will be included in the config. + + Returns: + A dictionary with the processor's configuration parameters. + """ + config = { + "max_length": self.max_length, + "task_key": self.task_key, + "padding_side": self.padding_side, + "padding": self.padding, + "truncation": self.truncation, + } + + # Only save tokenizer_name if it was used to create the tokenizer + if self.tokenizer_name is not None and self.tokenizer is None: + config["tokenizer_name"] = self.tokenizer_name + + return config + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Adds feature definitions for the language tokens and attention mask. + + This updates the policy features dictionary to include the new data added to the + observation, ensuring downstream components are aware of their shape and type. + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + # Add a feature for the token IDs if it doesn't already exist + if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + # Add a feature for the attention mask if it doesn't already exist + if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + return features diff --git a/src/lerobot/record.py b/src/lerobot/record.py index f39a05fb..d09b017e 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -21,11 +21,12 @@ Example: lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \ + --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/ \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ + --display_data=true # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ # --teleop.port=/dev/tty.usbmodem58760431551 \ @@ -59,9 +60,10 @@ lerobot-record \ import logging import time -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path from pprint import pformat +from typing import Any from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 @@ -72,10 +74,20 @@ from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import ( + PolicyAction, + PolicyProcessorPipeline, + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_processors, +) +from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -149,6 +161,8 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 + # Rename map for the observation to override the image and state keys + rename_map: dict[str, str] = field(default_factory=dict) def __post_init__(self): if self.single_task is None: @@ -187,14 +201,55 @@ class RecordConfig: return ["policy"] +""" --------------- record_loop() data flow -------------------------- + [ Robot ] + V + [ robot.get_observation() ] ---> raw_obs + V + [ robot_observation_processor ] ---> processed_obs + V + .-----( ACTION LOGIC )------------------. + V V + [ From Teleoperator ] [ From Policy ] + | | + | [teleop.get_action] -> raw_action | [predict_action] + | | | | + | V | V + | [teleop_action_processor] | | + | | | | + '---> processed_teleop_action '---> processed_policy_action + | | + '-------------------------.-------------' + V + [ robot_action_processor ] --> robot_action_to_send + V + [ robot.send_action() ] -- (Robot Executes) + V + ( Save to Dataset ) + V + ( Rerun Log / Loop Wait ) +""" + + @safe_stop_image_writer def record_loop( robot: Robot, events: dict, fps: int, + teleop_action_processor: RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction + ], # runs after teleop + robot_action_processor: RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction + ], # runs before robot + robot_observation_processor: RobotProcessorPipeline[ + RobotObservation, RobotObservation + ], # runs after robot dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None, + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None, control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, @@ -226,9 +281,11 @@ def record_loop( "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot." ) - # if policy is given it needs cleaning up - if policy is not None: + # Reset policy and processor if they are provided + if policy is not None and preprocessor is not None and postprocessor is not None: policy.reset() + preprocessor.reset() + postprocessor.reset() timestamp = 0 start_episode_t = time.perf_counter() @@ -239,32 +296,46 @@ def record_loop( events["exit_early"] = False break - observation = robot.get_observation() + # Get robot observation + obs = robot.get_observation() + + # Applies a pipeline to the raw robot observation, default is IdentityProcessor + obs_processed = robot_observation_processor(obs) if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") - if policy is not None: + # Get action from either policy or teleop + if policy is not None and preprocessor is not None and postprocessor is not None: action_values = predict_action( - observation_frame, - policy, - get_safe_torch_device(policy.config.device), - policy.config.use_amp, + observation=observation_frame, + policy=policy, + device=get_safe_torch_device(policy.config.device), + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, task=single_task, robot_type=robot.robot_type, ) - action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} + + action_names = dataset.features["action"]["names"] + act_processed_policy: RobotAction = { + f"{name}": float(action_values[i]) for i, name in enumerate(action_names) + } + elif policy is None and isinstance(teleop, Teleoperator): - action = teleop.get_action() + act = teleop.get_action() + + # Applies a pipeline to the raw teleop action, default is IdentityProcessor + act_processed_teleop = teleop_action_processor((act, obs)) + elif policy is None and isinstance(teleop, list): - # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline) arm_action = teleop_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - keyboard_action = teleop_keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_action) - - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act_processed_teleop = teleop_action_processor((act, obs)) else: logging.info( "No policy or teleoperator provided, skipping action generation." @@ -273,17 +344,28 @@ def record_loop( ) continue - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - sent_action = robot.send_action(action) + # Applies a pipeline to the action, default is IdentityProcessor + if policy is not None and act_processed_policy is not None: + action_values = act_processed_policy + robot_action_to_send = robot_action_processor((act_processed_policy, obs)) + else: + action_values = act_processed_teleop + robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) + # Send action to robot + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. action = postprocessor.process(action) + # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot. + _sent_action = robot.send_action(robot_action_to_send) + + # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") + action_frame = build_dataset_frame(dataset.features, action_values, prefix="action") frame = {**observation_frame, **action_frame, "task": single_task} dataset.add_frame(frame) if display_data: - log_rerun_data(observation, action) + log_rerun_data(observation=obs_processed, action=action_values) dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) @@ -301,9 +383,22 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None - action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video) - obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video) - dataset_features = {**action_features, **obs_features} + teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_action_processor, + initial_features=create_initial_features( + action=robot.action_features + ), # TODO(steven, pepijn): in future this should be come from teleop or policy + use_videos=cfg.dataset.video, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_observation_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=cfg.dataset.video, + ), + ) if cfg.resume: dataset = LeRobotDataset( @@ -335,6 +430,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + preprocessor = None + postprocessor = None + if cfg.policy is not None: + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.policy.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) robot.connect() if teleop is not None: @@ -350,8 +457,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot=robot, events=events, fps=cfg.dataset.fps, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, @@ -368,6 +480,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot=robot, events=events, fps=cfg.dataset.fps, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, control_time_s=cfg.dataset.reset_time_s, single_task=cfg.dataset.single_task, diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index cd76d114..6761e3f4 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -23,7 +23,7 @@ lerobot-replay \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ --dataset.repo_id=aliberts/record-test \ - --dataset.episode=2 + --dataset.episode=0 ``` Example replay with bimanual so100: @@ -45,9 +45,11 @@ from dataclasses import asdict, dataclass from pathlib import Path from pprint import pformat -import draccus - +from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor import ( + make_default_robot_action_processor, +) from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -55,7 +57,6 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, - reachy2, so100_follower, so101_follower, ) @@ -86,11 +87,13 @@ class ReplayConfig: play_sounds: bool = True -@draccus.wrap() +@parser.wrap() def replay(cfg: ReplayConfig): init_logging() logging.info(pformat(asdict(cfg))) + robot_action_processor = make_default_robot_action_processor() + robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) @@ -109,7 +112,11 @@ def replay(cfg: ReplayConfig): for i, name in enumerate(dataset.features["action"]["names"]): action[name] = action_array[i] - robot.send_action(action) + robot_obs = robot.get_observation() + + processed_action = robot_action_processor((action, robot_obs)) + + _ = robot.send_action(processed_action) dt_s = time.perf_counter() - start_episode_t busy_wait(1 / dataset.fps - dt_s) diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py index b995aab1..5dc43ac3 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/robots/so100_follower/__init__.py @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig +from .config_so100_follower import SO100FollowerConfig from .so100_follower import SO100Follower -from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index 561790e7..272b8c43 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False - - -@RobotConfig.register_subclass("so100_follower_end_effector") -@dataclass -class SO100FollowerEndEffectorConfig(SO100FollowerConfig): - """Configuration for the SO100FollowerEndEffector robot.""" - - # Path to URDF file for kinematics - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: - # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf - urdf_path: str | None = None - - # End-effector frame name in the URDF - target_frame_name: str = "gripper_frame_link" - - # Default bounds for the end-effector position (in meters) - end_effector_bounds: dict[str, list[float]] = field( - default_factory=lambda: { - "min": [-1.0, -1.0, -1.0], # min x, y, z - "max": [1.0, 1.0, 1.0], # max x, y, z - } - ) - - max_gripper_pos: float = 50 - - end_effector_step_sizes: dict[str, float] = field( - default_factory=lambda: { - "x": 0.02, - "y": 0.02, - "z": 0.02, - } - ) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py new file mode 100644 index 00000000..56686d44 --- /dev/null +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + EnvTransition, + ObservationProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + RobotAction, + RobotActionProcessorStep, + TransitionKey, +) +from lerobot.utils.rotation import Rotation + + +@ProcessorStepRegistry.register("ee_reference_and_delta") +@dataclass +class EEReferenceAndDelta(RobotActionProcessorStep): + """ + Computes a target end-effector pose from a relative delta command. + + This step takes a desired change in position and orientation (`target_*`) and applies it to a + reference end-effector pose to calculate an absolute target pose. The reference pose is derived + from the current robot joint positions using forward kinematics. + + The processor can operate in two modes: + 1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action + is first enabled. Subsequent commands are relative to this fixed reference. + 2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at + every step. + + Attributes: + kinematics: The robot's kinematic model for forward kinematics. + end_effector_step_sizes: A dictionary scaling the input delta commands. + motor_names: A list of motor names required for forward kinematics. + use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the + current pose as the reference. + reference_ee_pose: Internal state storing the latched reference pose. + _prev_enabled: Internal state to detect the rising edge of the enable signal. + _command_when_disabled: Internal state to hold the last command while disabled. + """ + + kinematics: RobotKinematics + end_effector_step_sizes: dict + motor_names: list[str] + use_latched_reference: bool = ( + True # If True, latch reference on enable; if False, always use current pose + ) + use_ik_solution: bool = False + + reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False) + _prev_enabled: bool = field(default=False, init=False, repr=False) + _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA): + q_raw = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"] + else: + q_raw = np.array( + [ + float(v) + for k, v in observation.items() + if isinstance(k, str) + and k.endswith(".pos") + and k.removesuffix(".pos") in self.motor_names + ], + dtype=float, + ) + + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + # Current pose from FK on measured joints + t_curr = self.kinematics.forward_kinematics(q_raw) + + enabled = bool(action.pop("enabled")) + tx = float(action.pop("target_x")) + ty = float(action.pop("target_y")) + tz = float(action.pop("target_z")) + wx = float(action.pop("target_wx")) + wy = float(action.pop("target_wy")) + wz = float(action.pop("target_wz")) + gripper_vel = float(action.pop("gripper_vel")) + + desired = None + + if enabled: + ref = t_curr + if self.use_latched_reference: + # Latched reference mode: latch reference at the rising edge + if not self._prev_enabled or self.reference_ee_pose is None: + self.reference_ee_pose = t_curr.copy() + ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr + + delta_p = np.array( + [ + tx * self.end_effector_step_sizes["x"], + ty * self.end_effector_step_sizes["y"], + tz * self.end_effector_step_sizes["z"], + ], + dtype=float, + ) + r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + desired = np.eye(4, dtype=float) + desired[:3, :3] = ref[:3, :3] @ r_abs + desired[:3, 3] = ref[:3, 3] + delta_p + + self._command_when_disabled = desired.copy() + else: + # While disabled, keep sending the same command to avoid drift. + if self._command_when_disabled is None: + # If we've never had an enabled command yet, freeze current FK pose once. + self._command_when_disabled = t_curr.copy() + desired = self._command_when_disabled.copy() + + # Write action fields + pos = desired[:3, 3] + tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() + action["ee.x"] = float(pos[0]) + action["ee.y"] = float(pos[1]) + action["ee.z"] = float(pos[2]) + action["ee.wx"] = float(tw[0]) + action["ee.wy"] = float(tw[1]) + action["ee.wz"] = float(tw[2]) + action["ee.gripper_vel"] = gripper_vel + + self._prev_enabled = enabled + return action + + def reset(self): + """Resets the internal state of the processor.""" + self._prev_enabled = False + self.reference_ee_pose = None + self._command_when_disabled = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in [ + "enabled", + "target_x", + "target_y", + "target_z", + "target_wx", + "target_wy", + "target_wz", + "gripper_vel", + ]: + features[PipelineFeatureType.ACTION].pop(f"{feat}", None) + + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]: + features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + +@ProcessorStepRegistry.register("ee_bounds_and_safety") +@dataclass +class EEBoundsAndSafety(RobotActionProcessorStep): + """ + Clips the end-effector pose to predefined bounds and checks for unsafe jumps. + + This step ensures that the target end-effector pose remains within a safe operational workspace. + It also moderates the command to prevent large, sudden movements between consecutive steps. + + Attributes: + end_effector_bounds: A dictionary with "min" and "max" keys for position clipping. + max_ee_step_m: The maximum allowed change in position (in meters) between steps. + max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps. + _last_pos: Internal state storing the last commanded position. + _last_twist: Internal state storing the last commanded orientation. + """ + + end_effector_bounds: dict + max_ee_step_m: float = 0.05 + max_ee_twist_step_rad: float = 0.20 + _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) + _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + x = action["ee.x"] + y = action["ee.y"] + z = action["ee.z"] + wx = action["ee.wx"] + wy = action["ee.wy"] + wz = action["ee.wz"] + # TODO(Steven): ee.gripper_vel does not need to be bounded + + if None in (x, y, z, wx, wy, wz): + raise ValueError( + "Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action" + ) + + pos = np.array([x, y, z], dtype=float) + twist = np.array([wx, wy, wz], dtype=float) + + # Clip position + pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"]) + + # Check for jumps in position + if self._last_pos is not None: + dpos = pos - self._last_pos + n = float(np.linalg.norm(dpos)) + if n > self.max_ee_step_m and n > 0: + pos = self._last_pos + dpos * (self.max_ee_step_m / n) + raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m") + + self._last_pos = pos + self._last_twist = twist + + action["ee.x"] = float(pos[0]) + action["ee.y"] = float(pos[1]) + action["ee.z"] = float(pos[2]) + action["ee.wx"] = float(twist[0]) + action["ee.wy"] = float(twist[1]) + action["ee.wz"] = float(twist[2]) + return action + + def reset(self): + """Resets the last known position and orientation.""" + self._last_pos = None + self._last_twist = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") +@dataclass +class InverseKinematicsEEToJoints(RobotActionProcessorStep): + """ + Computes desired joint positions from a target end-effector pose using inverse kinematics (IK). + + This step translates a Cartesian command (position and orientation of the end-effector) into + the corresponding joint-space commands for each motor. + + Attributes: + kinematics: The robot's kinematic model for inverse kinematics. + motor_names: A list of motor names for which to compute joint positions. + q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver. + initial_guess_current_joints: If True, use the robot's current joint state as the IK guess. + If False, use the solution from the previous step. + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def action(self, action: RobotAction) -> RobotAction: + x = action.pop("ee.x") + y = action.pop("ee.y") + z = action.pop("ee.z") + wx = action.pop("ee.wx") + wy = action.pop("ee.wy") + wz = action.pop("ee.wz") + gripper_pos = action.pop("ee.gripper_pos") + + if None in (x, y, z, wx, wy, wz, gripper_pos): + raise ValueError( + "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action" + ) + + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = q_raw + else: # Use previous ik solution as initial guess + if self.q_curr is None: + self.q_curr = q_raw + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + + # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + # TODO: This is sentitive to order of motor_names = q_target mapping + for i, name in enumerate(self.motor_names): + if name != "gripper": + action[f"{name}.pos"] = float(q_target[i]) + else: + action["gripper.pos"] = float(gripper_pos) + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None) + + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + def reset(self): + """Resets the initial guess for the IK solver.""" + self.q_curr = None + + +@ProcessorStepRegistry.register("gripper_velocity_to_joint") +@dataclass +class GripperVelocityToJoint(RobotActionProcessorStep): + """ + Converts a gripper velocity command into a target gripper joint position. + + This step integrates a normalized velocity command over time to produce a position command, + taking the current gripper position as a starting point. It also supports a discrete mode + where integer actions map to open, close, or no-op. + + Attributes: + motor_names: A list of motor names, which must include 'gripper'. + speed_factor: A scaling factor to convert the normalized velocity command to a position change. + clip_min: The minimum allowed gripper joint position. + clip_max: The maximum allowed gripper joint position. + discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay). + """ + + speed_factor: float = 20.0 + clip_min: float = 0.0 + clip_max: float = 100.0 + discrete_gripper: bool = False + + def action(self, action: RobotAction) -> RobotAction: + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + + gripper_vel = action.pop("ee.gripper_vel") + + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.discrete_gripper: + # Discrete gripper actions are in [0, 1, 2] + # 0: open, 1: close, 2: stay + # We need to shift them to [-1, 0, 1] and then scale them to clip_max + gripper_vel = (gripper_vel - 1) * self.clip_max + + # Compute desired gripper position + delta = gripper_vel * float(self.speed_factor) + # TODO: This assumes gripper is the last specified joint in the robot + gripper_pos = float(np.clip(q_raw[-1] + delta, self.clip_min, self.clip_max)) + action["ee.gripper_pos"] = gripper_pos + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None) + features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + +def compute_forward_kinematics_joints_to_ee( + joints: dict[str, Any], kinematics: RobotKinematics, motor_names: list[str] +) -> dict[str, Any]: + motor_joint_values = [joints[f"{n}.pos"] for n in motor_names] + + q = np.array(motor_joint_values, dtype=float) + t = kinematics.forward_kinematics(q) + pos = t[:3, 3] + tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() + gripper_pos = joints["gripper.pos"] + for n in motor_names: + joints.pop(f"{n}.pos") + joints["ee.x"] = float(pos[0]) + joints["ee.y"] = float(pos[1]) + joints["ee.z"] = float(pos[2]) + joints["ee.wx"] = float(tw[0]) + joints["ee.wy"] = float(tw[1]) + joints["ee.wz"] = float(tw[2]) + joints["ee.gripper_pos"] = float(gripper_pos) + return joints + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_observation") +@dataclass +class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep): + """ + Computes the end-effector pose from joint positions using forward kinematics (FK). + + This step is typically used to add the robot's Cartesian pose to the observation space, + which can be useful for visualization or as an input to a policy. + + Attributes: + kinematics: The robot's kinematic model. + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_action") +@dataclass +class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep): + """ + Computes the end-effector pose from joint positions using forward kinematics (FK). + + This step is typically used to add the robot's Cartesian pose to the observation space, + which can be useful for visualization or as an input to a policy. + + Attributes: + kinematics: The robot's kinematic model. + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def action(self, action: RobotAction) -> RobotAction: + return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register(name="forward_kinematics_joints_to_ee") +@dataclass +class ForwardKinematicsJointsToEE(ProcessorStep): + kinematics: RobotKinematics + motor_names: list[str] + + def __post_init__(self): + self.joints_to_ee_action_processor = ForwardKinematicsJointsToEEAction( + kinematics=self.kinematics, motor_names=self.motor_names + ) + self.joints_to_ee_observation_processor = ForwardKinematicsJointsToEEObservation( + kinematics=self.kinematics, motor_names=self.motor_names + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + if transition.get(TransitionKey.ACTION) is not None: + transition = self.joints_to_ee_action_processor(transition) + if transition.get(TransitionKey.OBSERVATION) is not None: + transition = self.joints_to_ee_observation_processor(transition) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + if features[PipelineFeatureType.ACTION] is not None: + features = self.joints_to_ee_action_processor.transform_features(features) + if features[PipelineFeatureType.OBSERVATION] is not None: + features = self.joints_to_ee_observation_processor.transform_features(features) + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_rl_step") +@dataclass +class InverseKinematicsRLStep(ProcessorStep): + """ + Computes desired joint positions from a target end-effector pose using inverse kinematics (IK). + + This is modified from the InverseKinematicsEEToJoints step to be used in the RL pipeline. + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = dict(transition) + action = new_transition.get(TransitionKey.ACTION) + if action is None: + raise ValueError("Action is required for InverseKinematicsEEToJoints") + action = dict(action) + + x = action.pop("ee.x") + y = action.pop("ee.y") + z = action.pop("ee.z") + wx = action.pop("ee.wx") + wy = action.pop("ee.wy") + wz = action.pop("ee.wz") + gripper_pos = action.pop("ee.gripper_pos") + + if None in (x, y, z, wx, wy, wz, gripper_pos): + raise ValueError( + "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action" + ) + + observation = new_transition.get(TransitionKey.OBSERVATION).copy() + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = q_raw + else: # Use previous ik solution as initial guess + if self.q_curr is None: + self.q_curr = q_raw + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + + # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + # TODO: This is sentitive to order of motor_names = q_target mapping + for i, name in enumerate(self.motor_names): + if name != "gripper": + action[f"{name}.pos"] = float(q_target[i]) + else: + action["gripper.pos"] = float(gripper_pos) + + new_transition[TransitionKey.ACTION] = action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + complementary_data["IK_solution"] = q_target + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None) + + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + def reset(self): + """Resets the initial guess for the IK solver.""" + self.q_curr = None diff --git a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py b/src/lerobot/robots/so100_follower/so100_follower_end_effector.py deleted file mode 100644 index 5fe2993c..00000000 --- a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py +++ /dev/null @@ -1,200 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from typing import Any - -import numpy as np - -from lerobot.cameras import make_cameras_from_configs -from lerobot.errors import DeviceNotConnectedError -from lerobot.model.kinematics import RobotKinematics -from lerobot.motors import Motor, MotorNormMode -from lerobot.motors.feetech import FeetechMotorsBus - -from . import SO100Follower -from .config_so100_follower import SO100FollowerEndEffectorConfig - -logger = logging.getLogger(__name__) - - -class SO100FollowerEndEffector(SO100Follower): - """ - SO100Follower robot with end-effector space control. - - This robot inherits from SO100Follower but transforms actions from - end-effector space to joint space before sending them to the motors. - """ - - config_class = SO100FollowerEndEffectorConfig - name = "so100_follower_end_effector" - - def __init__(self, config: SO100FollowerEndEffectorConfig): - super().__init__(config) - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - - self.cameras = make_cameras_from_configs(config.cameras) - - self.config = config - - # Initialize the kinematics module for the so100 robot - if self.config.urdf_path is None: - raise ValueError( - "urdf_path must be provided in the configuration for end-effector control. " - "Please set urdf_path in your SO100FollowerEndEffectorConfig." - ) - - self.kinematics = RobotKinematics( - urdf_path=self.config.urdf_path, - target_frame_name=self.config.target_frame_name, - ) - - # Store the bounds for end-effector position - self.end_effector_bounds = self.config.end_effector_bounds - - self.current_ee_pos = None - self.current_joint_pos = None - - @property - def action_features(self) -> dict[str, Any]: - """ - Define action features for end-effector control. - Returns dictionary with dtype, shape, and names. - """ - return { - "dtype": "float32", - "shape": (4,), - "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, - } - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """ - Transform action from end-effector space to joint space and send to motors. - - Args: - action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control - or a numpy array with [delta_x, delta_y, delta_z] - - Returns: - The joint-space action that was sent to the motors - """ - - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Convert action to numpy array if not already - if isinstance(action, dict): - if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): - delta_ee = np.array( - [ - action["delta_x"] * self.config.end_effector_step_sizes["x"], - action["delta_y"] * self.config.end_effector_step_sizes["y"], - action["delta_z"] * self.config.end_effector_step_sizes["z"], - ], - dtype=np.float32, - ) - if "gripper" not in action: - action["gripper"] = [1.0] - action = np.append(delta_ee, action["gripper"]) - else: - logger.warning( - f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" - ) - action = np.zeros(4, dtype=np.float32) - - if self.current_joint_pos is None: - # Read current joint positions - current_joint_pos = self.bus.sync_read("Present_Position") - self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) - - # Calculate current end-effector position using forward kinematics - if self.current_ee_pos is None: - self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) - - # Set desired end-effector position by adding delta - desired_ee_pos = np.eye(4) - desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation - - # Add delta to position and clip to bounds - desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] - if self.end_effector_bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], - self.end_effector_bounds["min"], - self.end_effector_bounds["max"], - ) - - # Compute inverse kinematics to get joint positions - target_joint_values_in_degrees = self.kinematics.inverse_kinematics( - self.current_joint_pos, desired_ee_pos - ) - - # Create joint space action dictionary - joint_action = { - f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) - } - - # Handle gripper separately if included in action - # Gripper delta action is in the range 0 - 2, - # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos - joint_action["gripper.pos"] = np.clip( - self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, - 5, - self.config.max_gripper_pos, - ) - - self.current_ee_pos = desired_ee_pos.copy() - self.current_joint_pos = target_joint_values_in_degrees.copy() - self.current_joint_pos[-1] = joint_action["gripper.pos"] - - # Send joint space action to parent class - return super().send_action(joint_action) - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Read arm position - start = time.perf_counter() - obs_dict = self.bus.sync_read("Present_Position") - obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read state: {dt_ms:.1f}ms") - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - - return obs_dict - - def reset(self): - self.current_ee_pos = None - self.current_joint_pos = None diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 261e59a3..0455bce3 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .so100_follower import SO100Follower return SO100Follower(config) - elif config.type == "so100_follower_end_effector": - from .so100_follower import SO100FollowerEndEffector - - return SO100FollowerEndEffector(config) elif config.type == "so101_follower": from .so101_follower import SO101Follower @@ -73,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot: raise ValueError(config.type) +# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset def ensure_safe_goal_position( goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float] ) -> dict[str, float]: diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 13d30c68..bf398a0a 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -56,6 +56,7 @@ from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat +from typing import Any import einops import gymnasium as gym @@ -69,9 +70,9 @@ from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters +from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -84,6 +85,8 @@ from lerobot.utils.utils import ( def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], seeds: list[int] | None = None, return_observations: bool = False, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, @@ -120,7 +123,6 @@ def rollout( The dictionary described above. """ assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." - device = get_device_from_parameters(policy) # Reset the policy and environments. policy.reset() @@ -151,23 +153,20 @@ def rollout( if return_observations: all_observations.append(deepcopy(observation)) - observation = { - key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation - } - # Infer "task" from attributes of environments. # TODO: works with SyncVectorEnv but not AsyncVectorEnv observation = add_envs_task(env, observation) - + observation = preprocessor(observation) with torch.inference_mode(): action = policy.select_action(observation) + action = postprocessor(action) # Convert to CPU / numpy. - action = action.to("cpu").numpy() - assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" + action_numpy: np.ndarray = action.to("cpu").numpy() + assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" # Apply the next action. - observation, reward, terminated, truncated, info = env.step(action) + observation, reward, terminated, truncated, info = env.step(action_numpy) if render_callback is not None: render_callback(env) @@ -181,7 +180,7 @@ def rollout( # Keep track of which environments are done so far. done = terminated | truncated | done - all_actions.append(torch.from_numpy(action)) + all_actions.append(torch.from_numpy(action_numpy)) all_rewards.append(torch.from_numpy(reward)) all_dones.append(torch.from_numpy(done)) all_successes.append(torch.tensor(successes)) @@ -220,6 +219,8 @@ def rollout( def eval_policy( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, max_episodes_rendered: int = 0, videos_dir: Path | None = None, @@ -296,8 +297,10 @@ def eval_policy( start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) ) rollout_data = rollout( - env, - policy, + env=env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, @@ -479,13 +482,22 @@ def eval_main(cfg: EvalPipelineConfig): cfg=cfg.policy, env_cfg=cfg.env, ) + policy.eval() + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, + ) with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy( - env, - policy, - cfg.eval.n_episodes, + env=env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, max_episodes_rendered=10, videos_dir=Path(cfg.output_dir) / "videos", start_seed=cfg.seed, diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 1c8f9286..baa284c4 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -62,9 +62,16 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.processor import TransitionKey from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl.gym_manipulator import make_robot_env +from lerobot.scripts.rl.gym_manipulator import ( + create_transition, + make_processors, + make_robot_env, + step_env_and_process_transition, +) from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import ( bytes_to_state_dict, @@ -91,10 +98,7 @@ from lerobot.utils.utils import ( ACTOR_SHUTDOWN_TIMEOUT = 30 - -################################################# -# Main entry point # -################################################# +# Main entry point @parser.wrap() @@ -201,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig): logging.info("[ACTOR] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def act_with_policy( @@ -236,7 +238,8 @@ def act_with_policy( logging.info("make_env online") - online_env = make_robot_env(cfg=cfg.env) + online_env, teleop_device = make_robot_env(cfg=cfg.env) + env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device) set_seed(cfg.seed) device = get_safe_torch_device(cfg.policy.device, log=True) @@ -257,6 +260,12 @@ def act_with_policy( assert isinstance(policy, nn.Module) obs, info = online_env.reset() + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 @@ -274,45 +283,71 @@ def act_with_policy( logging.info("[ACTOR] Shutting down act_with_policy") return - if interaction_step >= cfg.policy.online_step_before_learning: - # Time policy inference and check if it meets FPS requirement - with policy_timer: - action = policy.select_action(batch=obs) - policy_fps = policy_timer.fps_last + observation = { + k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features + } - log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + # Time policy inference and check if it meets FPS requirement + with policy_timer: + # Extract observation from transition for policy + action = policy.select_action(batch=observation) + policy_fps = policy_timer.fps_last - else: - action = online_env.action_space.sample() + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) - next_obs, reward, done, truncated, info = online_env.step(action) + # Use the new step function + new_transition = step_env_and_process_transition( + env=online_env, + transition=transition, + action=action, + env_processor=env_processor, + action_processor=action_processor, + ) + + # Extract values from processed transition + next_observation = { + k: v + for k, v in new_transition[TransitionKey.OBSERVATION].items() + if k in cfg.policy.input_features + } + + # Teleop action is the action that was executed in the environment + # It is either the action from the teleop device or the action from the policy + executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] + + reward = new_transition[TransitionKey.REWARD] + done = new_transition.get(TransitionKey.DONE, False) + truncated = new_transition.get(TransitionKey.TRUNCATED, False) sum_reward_episode += float(reward) - # Increment total steps counter for intervention rate episode_total_steps += 1 - # NOTE: We override the action if the intervention is True, because the action applied is the intervention action - if "is_intervention" in info and info["is_intervention"]: - # NOTE: The action space for demonstration before hand is with the full action space - # but sometimes for example we want to deactivate the gripper - action = info["action_intervention"] + # Check for intervention from transition info + intervention_info = new_transition[TransitionKey.INFO] + if intervention_info.get(TeleopEvents.IS_INTERVENTION, False): episode_intervention = True - # Increment intervention steps counter episode_intervention_steps += 1 + complementary_info = { + "discrete_penalty": torch.tensor( + [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)] + ), + } + # Create transition for learner (convert to old format) list_transition_to_send_to_learner.append( Transition( - state=obs, - action=action, + state=observation, + action=executed_action, reward=reward, - next_state=next_obs, + next_state=next_observation, done=done, - truncated=truncated, # TODO: (azouitine) Handle truncation properly - complementary_info=info, + truncated=truncated, + complementary_info=complementary_info, ) ) - # assign obs to the next obs and continue the rollout - obs = next_obs + + # Update transition for next iteration + transition = new_transition if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") @@ -347,21 +382,27 @@ def act_with_policy( ) ) - # Reset intervention counters + # Reset intervention counters and environment sum_reward_episode = 0.0 episode_intervention = False episode_intervention_steps = 0 episode_total_steps = 0 + + # Reset environment and processors obs, info = online_env.reset() + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time busy_wait(1 / cfg.env.fps - dt_time) -################################################# -# Communication Functions - Group all gRPC/messaging functions # -################################################# +# Communication Functions - Group all gRPC/messaging functions def establish_learner_connection( @@ -606,9 +647,7 @@ def interactions_stream( return services_pb2.Empty() -################################################# -# Policy functions # -################################################# +# Policy functions def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): @@ -640,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") -################################################# -# Utilities functions # -################################################# +# Utilities functions def push_transitions_to_transport_queue(transitions: list, transitions_queue): diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 046be03e..f91d077f 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -14,65 +14,95 @@ # See the License for the specific language governing permissions and # limitations under the License. - -""" -Robot Environment for LeRobot Manipulation Tasks - -This module provides a comprehensive gym-compatible environment for robot manipulation -with support for: -- Multiple robot types (SO100, SO101, Koch and Moss) -- Human intervention via leader-follower control or gamepad - -- End-effector and joint space control -- Image processing (cropping and resizing) - -The environment is built using a composable wrapper pattern where each wrapper -adds specific functionality to the base RobotEnv. - -Example: - env = make_robot_env(cfg) - obs, info = env.reset() - action = policy.select_action(obs) - obs, reward, terminated, truncated, info = env.step(action) -""" - import logging import time -from collections import deque -from collections.abc import Sequence -from threading import Lock -from typing import Annotated, Any +from dataclasses import dataclass +from typing import Any import gymnasium as gym import numpy as np import torch -import torchvision.transforms.functional as F # noqa: N812 from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.envs.configs import EnvConfig -from lerobot.envs.utils import preprocess_observation +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.envs.configs import HILSerlRobotEnvConfig from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + AddTeleopActionAsComplimentaryDataStep, + AddTeleopEventsAsInfoStep, + DataProcessorPipeline, + DeviceProcessorStep, + EnvTransition, + GripperPenaltyProcessorStep, + ImageCropResizeProcessorStep, + InterventionActionProcessorStep, + JointVelocityProcessorStep, + MapDeltaActionToRobotActionStep, + MapTensorToDeltaActionDictStep, + MotorCurrentProcessorStep, + Numpy2TorchActionProcessorStep, + RewardClassifierProcessorStep, + RobotActionToPolicyActionProcessorStep, + TimeLimitProcessorStep, + Torch2NumpyActionProcessorStep, + TransitionKey, + VanillaObservationProcessorStep, + create_transition, +) +from lerobot.processor.converters import identity_transition from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, so100_follower, ) +from lerobot.robots.robot import Robot +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEEObservation, + GripperVelocityToJoint, + InverseKinematicsRLStep, +) from lerobot.teleoperators import ( gamepad, # noqa: F401 keyboard, # noqa: F401 make_teleoperator_from_config, so101_leader, # noqa: F401 ) -from lerobot.teleoperators.gamepad.teleop_gamepad import GamepadTeleop -from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardEndEffectorTeleop +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say logging.basicConfig(level=logging.INFO) -def reset_follower_position(robot_arm, target_position): +@dataclass +class DatasetConfig: + """Configuration for dataset creation and management.""" + + repo_id: str + task: str + root: str | None = None + num_episodes_to_record: int = 5 + replay_episode: int | None = None + push_to_hub: bool = False + + +@dataclass +class GymManipulatorConfig: + """Main configuration for gym manipulator environment.""" + + env: HILSerlRobotEnvConfig + dataset: DatasetConfig + mode: str | None = None # Either "record", "replay", None + device: str = "cpu" + + +def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None: + """Reset robot arm to target position using smooth trajectory.""" current_position_dict = robot_arm.bus.sync_read("Present_Position") current_position = np.array( [current_position_dict[name] for name in current_position_dict], dtype=np.float32 @@ -86,158 +116,25 @@ def reset_follower_position(robot_arm, target_position): busy_wait(0.015) -class TorchBox(gym.spaces.Box): - """ - A version of gym.spaces.Box that handles PyTorch tensors. - - This class extends gym.spaces.Box to work with PyTorch tensors, - providing compatibility between NumPy arrays and PyTorch tensors. - """ - - def __init__( - self, - low: float | Sequence[float] | np.ndarray, - high: float | Sequence[float] | np.ndarray, - shape: Sequence[int] | None = None, - np_dtype: np.dtype | type = np.float32, - torch_dtype: torch.dtype = torch.float32, - device: str = "cpu", - seed: int | np.random.Generator | None = None, - ) -> None: - """ - Initialize the PyTorch-compatible Box space. - - Args: - low: Lower bounds of the space. - high: Upper bounds of the space. - shape: Shape of the space. If None, inferred from low and high. - np_dtype: NumPy data type for internal storage. - torch_dtype: PyTorch data type for tensor conversion. - device: PyTorch device for returned tensors. - seed: Random seed for sampling. - """ - super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) - self.torch_dtype = torch_dtype - self.device = device - - def sample(self) -> torch.Tensor: - """ - Sample a random point from the space. - - Returns: - A PyTorch tensor within the space bounds. - """ - arr = super().sample() - return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) - - def contains(self, x: torch.Tensor) -> bool: - """ - Check if a tensor is within the space bounds. - - Args: - x: The PyTorch tensor to check. - - Returns: - Boolean indicating whether the tensor is within bounds. - """ - # Move to CPU/numpy and cast to the internal dtype - arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) - return super().contains(arr) - - def seed(self, seed: int | np.random.Generator | None = None): - """ - Set the random seed for sampling. - - Args: - seed: The random seed to use. - - Returns: - List containing the seed. - """ - super().seed(seed) - return [seed] - - def __repr__(self) -> str: - """ - Return a string representation of the space. - - Returns: - Formatted string with space details. - """ - return ( - f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " - f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" - ) - - -class TorchActionWrapper(gym.Wrapper): - """ - Wrapper that changes the action space to use PyTorch tensors. - - This wrapper modifies the action space to return PyTorch tensors when sampled - and handles converting PyTorch actions to NumPy when stepping the environment. - """ - - def __init__(self, env: gym.Env, device: str): - """ - Initialize the PyTorch action space wrapper. - - Args: - env: The environment to wrap. - device: The PyTorch device to use for tensor operations. - """ - super().__init__(env) - self.action_space = TorchBox( - low=env.action_space.low, - high=env.action_space.high, - shape=env.action_space.shape, - torch_dtype=torch.float32, - device=torch.device("cpu"), - ) - - def step(self, action: torch.Tensor): - """ - Step the environment with a PyTorch tensor action. - - This method handles conversion from PyTorch tensors to NumPy arrays - for compatibility with the underlying environment. - - Args: - action: PyTorch tensor action to take. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - if action.dim() == 2: - action = action.squeeze(0) - action = action.detach().cpu().numpy() - return self.env.step(action) - - class RobotEnv(gym.Env): - """ - Gym-compatible environment for evaluating robotic control policies with integrated human intervention. - - This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) - and absolute joint position commands and automatically configures its observation and action spaces based on the robot's - sensors and configuration. - """ + """Gym environment for robotic control with human intervention support.""" def __init__( self, robot, use_gripper: bool = False, display_cameras: bool = False, - ): - """ - Initialize the RobotEnv environment. - - The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup - supports both relative (delta) adjustments and absolute joint positions for controlling the robot. + reset_pose: list[float] | None = None, + reset_time_s: float = 5.0, + ) -> None: + """Initialize robot environment with configuration options. Args: - robot: The robot interface object used to connect and interact with the physical robot. - display_cameras: If True, the robot's camera feeds will be displayed during execution. + robot: Robot interface for hardware communication. + use_gripper: Whether to include gripper in action space. + display_cameras: Whether to show camera feeds during execution. + reset_pose: Joint positions for environment reset. + reset_time_s: Time to wait during reset. """ super().__init__() @@ -255,52 +152,50 @@ class RobotEnv(gym.Env): self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] self._image_keys = self.robot.cameras.keys() - self.current_observation = None + self.reset_pose = reset_pose + self.reset_time_s = reset_time_s self.use_gripper = use_gripper + self._joint_names = list(self.robot.bus.motors.keys()) + self._raw_joint_positions = None + self._setup_spaces() - def _get_observation(self) -> dict[str, np.ndarray]: - """Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" + def _get_observation(self) -> dict[str, Any]: + """Get current robot observation including joint positions and camera images.""" obs_dict = self.robot.get_observation() - joint_positions = np.array([obs_dict[name] for name in self._joint_names]) + raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names} + joint_positions = np.array([raw_joint_joint_position[f"{name}.pos"] for name in self._joint_names]) images = {key: obs_dict[key] for key in self._image_keys} - self.current_observation = {"agent_pos": joint_positions, "pixels": images} - def _setup_spaces(self): - """ - Dynamically configure the observation and action spaces based on the robot's capabilities. + return {"agent_pos": joint_positions, "pixels": images, **raw_joint_joint_position} - Observation Space: - - For keys with "image": A Box space with pixel values ranging from 0 to 255. - - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. - - Action Space: - - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) - or absolute, based on the configuration. - """ - self._get_observation() + def _setup_spaces(self) -> None: + """Configure observation and action spaces based on robot capabilities.""" + current_observation = self._get_observation() observation_spaces = {} # Define observation spaces for images and other states. - if "pixels" in self.current_observation: + if current_observation is not None and "pixels" in current_observation: prefix = "observation.images" observation_spaces = { f"{prefix}.{key}": gym.spaces.Box( - low=0, high=255, shape=self.current_observation["pixels"][key].shape, dtype=np.uint8 + low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 ) - for key in self.current_observation["pixels"] + for key in current_observation["pixels"] } - observation_spaces["observation.state"] = gym.spaces.Box( - low=0, - high=10, - shape=self.current_observation["agent_pos"].shape, - dtype=np.float32, - ) + if current_observation is not None: + agent_pos = current_observation["agent_pos"] + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=agent_pos.shape, + dtype=np.float32, + ) self.observation_space = gym.spaces.Dict(observation_spaces) @@ -322,57 +217,46 @@ class RobotEnv(gym.Env): dtype=np.float32, ) - def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]: - """ - Reset the environment to its initial state. - This method resets the step counter and clears any episodic data. + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Reset environment to initial state. Args: - seed: A seed for random number generation to ensure reproducibility. - options: Additional options to influence the reset behavior. + seed: Random seed for reproducibility. + options: Additional reset options. Returns: - A tuple containing: - - observation (dict): The initial sensor observation. - - info (dict): A dictionary with supplementary information, including the key "is_intervention". + Tuple of (observation, info) dictionaries. """ - super().reset(seed=seed, options=options) + # Reset the robot + # self.robot.reset() + start_time = time.perf_counter() + if self.reset_pose is not None: + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.robot, np.array(self.reset_pose)) + log_say("Reset the environment done.", play_sounds=True) - self.robot.reset() + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + + super().reset(seed=seed, options=options) # Reset episode tracking variables. self.current_step = 0 self.episode_data = None - self.current_observation = None - self._get_observation() - return self.current_observation, {"is_intervention": False} + obs = self._get_observation() + self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} + return obs, {TeleopEvents.IS_INTERVENTION: False} def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: - """ - Execute a single step within the environment using the specified action. + """Execute one environment step with given action.""" + joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())} - The provided action is processed and sent to the robot as joint position commands - that may be either absolute values or deltas based on the environment configuration. + self.robot.send_action(joint_targets_dict) - Args: - action: The commanded joint positions as a numpy array or torch tensor. + obs = self._get_observation() - Returns: - A tuple containing: - - observation (dict): The new sensor observation after taking the step. - - reward (float): The step reward (default is 0.0 within this wrapper). - - terminated (bool): True if the episode has reached a terminal state. - - truncated (bool): True if the episode was truncated (e.g., time constraints). - - info (dict): Additional debugging information including intervention status. - """ - action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]} - - # 1.0 action corresponds to no-op action - action_dict["gripper"] = action[3] if self.use_gripper else 1.0 - - self.robot.send_action(action_dict) - - self._get_observation() + self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} if self.display_cameras: self.render() @@ -384,1880 +268,501 @@ class RobotEnv(gym.Env): truncated = False return ( - self.current_observation, + obs, reward, terminated, truncated, - {"is_intervention": False}, + {TeleopEvents.IS_INTERVENTION: False}, ) - def render(self): - """ - Render the current state of the environment by displaying the robot's camera feeds. - """ + def render(self) -> None: + """Display robot camera feeds.""" import cv2 - image_keys = [key for key in self.current_observation if "image" in key] + current_observation = self._get_observation() + if current_observation is not None: + image_keys = [key for key in current_observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(self.current_observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(current_observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) - def close(self): - """ - Close the environment and clean up resources by disconnecting the robot. - - If the robot is currently connected, this method properly terminates the connection to ensure that all - associated resources are released. - """ + def close(self) -> None: + """Close environment and disconnect robot.""" if self.robot.is_connected: self.robot.disconnect() + def get_raw_joint_positions(self) -> dict[str, float]: + """Get raw joint positions.""" + return self._raw_joint_positions -class AddJointVelocityToObservation(gym.ObservationWrapper): - """ - Wrapper that adds joint velocity information to the observation. - This wrapper computes joint velocities by tracking changes in joint positions over time, - and extends the observation space to include these velocities. - """ - - def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): - """ - Initialize the joint velocity wrapper. - - Args: - env: The environment to wrap. - joint_velocity_limits: Maximum expected joint velocity for space bounds. - fps: Frames per second used to calculate velocity (position delta / time). - num_dof: Number of degrees of freedom (joints) in the robot. - """ - super().__init__(env) - - # Extend observation space to include joint velocities - old_low = self.observation_space["observation.state"].low - old_high = self.observation_space["observation.state"].high - old_shape = self.observation_space["observation.state"].shape - - self.last_joint_positions = np.zeros(num_dof) - - new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits]) - new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits]) - - new_shape = (old_shape[0] + num_dof,) - - self.observation_space["observation.state"] = gym.spaces.Box( - low=new_low, - high=new_high, - shape=new_shape, - dtype=np.float32, - ) - - self.dt = 1.0 / fps - - def observation(self, observation): - """ - Add joint velocity information to the observation. - - Args: - observation: The original observation from the environment. - - Returns: - The modified observation with joint velocities. - """ - joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt - self.last_joint_positions = observation["agent_pos"] - observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1) - return observation - - -class AddCurrentToObservation(gym.ObservationWrapper): - """ - Wrapper that adds motor current information to the observation. - - This wrapper extends the observation space to include the current values - from each motor, providing information about the forces being applied. - """ - - def __init__(self, env, max_current=500, num_dof=6): - """ - Initialize the current observation wrapper. - - Args: - env: The environment to wrap. - max_current: Maximum expected current for space bounds. - num_dof: Number of degrees of freedom (joints) in the robot. - """ - super().__init__(env) - - # Extend observation space to include joint velocities - old_low = self.observation_space["observation.state"].low - old_high = self.observation_space["observation.state"].high - old_shape = self.observation_space["observation.state"].shape - - new_low = np.concatenate([old_low, np.zeros(num_dof)]) - new_high = np.concatenate([old_high, np.ones(num_dof) * max_current]) - - new_shape = (old_shape[0] + num_dof,) - - self.observation_space["observation.state"] = gym.spaces.Box( - low=new_low, - high=new_high, - shape=new_shape, - dtype=np.float32, - ) - - def observation(self, observation): - """ - Add current information to the observation. - - Args: - observation: The original observation from the environment. - - Returns: - The modified observation with current values. - """ - present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current") - present_current_observation = np.array( - [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors] - ) - observation["agent_pos"] = np.concatenate( - [observation["agent_pos"], present_current_observation], axis=-1 - ) - return observation - - -class RewardWrapper(gym.Wrapper): - def __init__(self, env, reward_classifier, device="cuda"): - """ - Wrapper to add reward prediction to the environment using a trained classifier. - - Args: - env: The environment to wrap. - reward_classifier: The reward classifier model. - device: The device to run the model on. - """ - self.env = env - - self.device = device - - self.reward_classifier = torch.compile(reward_classifier) - self.reward_classifier.to(self.device) - - def step(self, action): - """ - Execute a step and compute the reward using the classifier. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - observation, _, terminated, truncated, info = self.env.step(action) - - images = {} - for key in observation: - if "image" in key: - images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) - if images[key].dim() == 3: - images[key] = images[key].unsqueeze(0) - - start_time = time.perf_counter() - with torch.inference_mode(): - success = ( - self.reward_classifier.predict_reward(images, threshold=0.7) - if self.reward_classifier is not None - else 0.0 - ) - info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) - - reward = 0.0 - if success == 1.0: - terminated = True - reward = 1.0 - - return observation, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - return self.env.reset(seed=seed, options=options) - - -class TimeLimitWrapper(gym.Wrapper): - """ - Wrapper that adds a time limit to episodes and tracks execution time. - - This wrapper terminates episodes after a specified time has elapsed, providing - better control over episode length. - """ - - def __init__(self, env, control_time_s, fps): - """ - Initialize the time limit wrapper. - - Args: - env: The environment to wrap. - control_time_s: Maximum episode duration in seconds. - fps: Frames per second for calculating the maximum number of steps. - """ - self.env = env - self.control_time_s = control_time_s - self.fps = fps - - self.last_timestamp = 0.0 - self.episode_time_in_s = 0.0 - - self.max_episode_steps = int(self.control_time_s * self.fps) - - self.current_step = 0 - - def step(self, action): - """ - Step the environment and track time elapsed. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - obs, reward, terminated, truncated, info = self.env.step(action) - time_since_last_step = time.perf_counter() - self.last_timestamp - self.episode_time_in_s += time_since_last_step - self.last_timestamp = time.perf_counter() - self.current_step += 1 - # check if last timestep took more time than the expected fps - if 1.0 / time_since_last_step < self.fps: - logging.debug(f"Current timestep exceeded expected fps {self.fps}") - - if self.current_step >= self.max_episode_steps: - terminated = True - return obs, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment and time tracking. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - self.episode_time_in_s = 0.0 - self.last_timestamp = time.perf_counter() - self.current_step = 0 - return self.env.reset(seed=seed, options=options) - - -class ImageCropResizeWrapper(gym.Wrapper): - """ - Wrapper that crops and resizes image observations. - - This wrapper processes image observations to focus on relevant regions by - cropping and then resizing to a standard size. - """ - - def __init__( - self, - env, - crop_params_dict: dict[str, Annotated[tuple[int], 4]], - resize_size=None, - ): - """ - Initialize the image crop and resize wrapper. - - Args: - env: The environment to wrap. - crop_params_dict: Dictionary mapping image observation keys to crop parameters - (top, left, height, width). - resize_size: Target size for resized images (height, width). Defaults to (128, 128). - """ - super().__init__(env) - self.env = env - self.crop_params_dict = crop_params_dict - print(f"obs_keys , {self.env.observation_space}") - print(f"crop params dict {crop_params_dict.keys()}") - for key_crop in crop_params_dict: - if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 - raise ValueError(f"Key {key_crop} not in observation space") - for key in crop_params_dict: - new_shape = (3, resize_size[0], resize_size[1]) - self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) - - self.resize_size = resize_size - if self.resize_size is None: - self.resize_size = (128, 128) - - def step(self, action): - """ - Step the environment and process image observations. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info) with processed images. - """ - obs, reward, terminated, truncated, info = self.env.step(action) - for k in self.crop_params_dict: - device = obs[k].device - if obs[k].dim() >= 3: - # Reshape to combine height and width dimensions for easier calculation - batch_size = obs[k].size(0) - channels = obs[k].size(1) - flattened_spatial_dims = obs[k].view(batch_size, channels, -1) - - # Calculate standard deviation across spatial dimensions (H, W) - # If any channel has std=0, all pixels in that channel have the same value - # This is helpful if one camera mistakenly covered or the image is black - std_per_channel = torch.std(flattened_spatial_dims, dim=2) - if (std_per_channel <= 0.02).any(): - logging.warning( - f"Potential hardware issue detected: All pixels have the same value in observation {k}" - ) - - if device == torch.device("mps:0"): - obs[k] = obs[k].cpu() - - obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) - obs[k] = F.resize(obs[k], self.resize_size) - # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] - obs[k] = obs[k].clamp(0.0, 1.0) - obs[k] = obs[k].to(device) - - return obs, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment and process image observations. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - Tuple of (observation, info) with processed images. - """ - obs, info = self.env.reset(seed=seed, options=options) - for k in self.crop_params_dict: - device = obs[k].device - if device == torch.device("mps:0"): - obs[k] = obs[k].cpu() - obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) - obs[k] = F.resize(obs[k], self.resize_size) - obs[k] = obs[k].clamp(0.0, 1.0) - obs[k] = obs[k].to(device) - return obs, info - - -class ConvertToLeRobotObservation(gym.ObservationWrapper): - """ - Wrapper that converts standard observations to LeRobot format. - - This wrapper processes observations to match the expected format for LeRobot, - including normalizing image values and moving tensors to the specified device. - """ - - def __init__(self, env, device: str = "cpu"): - """ - Initialize the LeRobot observation converter. - - Args: - env: The environment to wrap. - device: Target device for the observation tensors. - """ - super().__init__(env) - - self.device = torch.device(device) - - def observation(self, observation): - """ - Convert observations to LeRobot format. - - Args: - observation: The original observation from the environment. - - Returns: - The processed observation with normalized images and proper tensor formats. - """ - observation = preprocess_observation(observation) - observation = { - key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") - for key in observation - } - return observation - - -class ResetWrapper(gym.Wrapper): - """ - Wrapper that handles environment reset procedures. - - This wrapper provides additional functionality during environment reset, - including the option to reset to a fixed pose or allow manual reset. - """ - - def __init__( - self, - env: RobotEnv, - reset_pose: np.ndarray | None = None, - reset_time_s: float = 5, - ): - """ - Initialize the reset wrapper. - - Args: - env: The environment to wrap. - reset_pose: Fixed joint positions to reset to. If None, manual reset is used. - reset_time_s: Time in seconds to wait after reset or allowed for manual reset. - """ - super().__init__(env) - self.reset_time_s = reset_time_s - self.reset_pose = reset_pose - self.robot = self.unwrapped.robot - - def reset(self, *, seed=None, options=None): - """ - Reset the environment with either fixed or manual reset procedure. - - If reset_pose is provided, the robot will move to that position. - Otherwise, manual teleoperation control is allowed for reset_time_s seconds. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - start_time = time.perf_counter() - if self.reset_pose is not None: - log_say("Reset the environment.", play_sounds=True) - reset_follower_position(self.unwrapped.robot, self.reset_pose) - log_say("Reset the environment done.", play_sounds=True) - - if hasattr(self.env, "robot_leader"): - self.env.robot_leader.bus.sync_write("Torque_Enable", 1) - log_say("Reset the leader robot.", play_sounds=True) - reset_follower_position(self.env.robot_leader, self.reset_pose) - log_say("Reset the leader robot done.", play_sounds=True) - else: - log_say( - f"Manually reset the environment for {self.reset_time_s} seconds.", - play_sounds=True, - ) - start_time = time.perf_counter() - while time.perf_counter() - start_time < self.reset_time_s: - action = self.env.robot_leader.get_action() - self.unwrapped.robot.send_action(action) - - log_say("Manual reset of the environment done.", play_sounds=True) - - busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) - - return super().reset(seed=seed, options=options) - - -class BatchCompatibleWrapper(gym.ObservationWrapper): - """ - Wrapper that ensures observations are compatible with batch processing. - - This wrapper adds a batch dimension to observations that don't already have one, - making them compatible with models that expect batched inputs. - """ - - def __init__(self, env): - """ - Initialize the batch compatibility wrapper. - - Args: - env: The environment to wrap. - """ - super().__init__(env) - - def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - Add batch dimensions to observations if needed. - - Args: - observation: Dictionary of observation tensors. - - Returns: - Dictionary of observation tensors with batch dimensions. - """ - for key in observation: - if "image" in key and observation[key].dim() == 3: - observation[key] = observation[key].unsqueeze(0) - if "state" in key and observation[key].dim() == 1: - observation[key] = observation[key].unsqueeze(0) - if "velocity" in key and observation[key].dim() == 1: - observation[key] = observation[key].unsqueeze(0) - return observation - - -class GripperPenaltyWrapper(gym.RewardWrapper): - """ - Wrapper that adds penalties for inefficient gripper commands. - - This wrapper modifies rewards to discourage excessive gripper movement - or commands that attempt to move the gripper beyond its physical limits. - """ - - def __init__(self, env, penalty: float = -0.1): - """ - Initialize the gripper penalty wrapper. - - Args: - env: The environment to wrap. - penalty: Negative reward value to apply for inefficient gripper actions. - """ - super().__init__(env) - self.penalty = penalty - self.last_gripper_state = None - - def reward(self, reward, action): - """ - Apply penalties to reward based on gripper actions. - - Args: - reward: The original reward from the environment. - action: The action that was taken. - - Returns: - Modified reward with penalty applied if necessary. - """ - gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos - - action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND - - gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( - gripper_state_normalized > 0.75 and action_normalized < -0.5 - ) - - return reward + self.penalty * int(gripper_penalty_bool) - - def step(self, action): - """ - Step the environment and apply gripper penalties. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info) with penalty applied. - """ - self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] - - gripper_action = action[-1] - obs, reward, terminated, truncated, info = self.env.step(action) - gripper_penalty = self.reward(reward, gripper_action) - - info["discrete_penalty"] = gripper_penalty - - return obs, reward, terminated, truncated, info - - def reset(self, **kwargs): - """ - Reset the environment and penalty tracking. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info with gripper penalty initialized. - """ - self.last_gripper_state = None - obs, info = super().reset(**kwargs) - info["gripper_penalty"] = 0.0 - return obs, info - - -class GripperActionWrapper(gym.ActionWrapper): - """ - Wrapper that processes gripper control commands. - - This wrapper quantizes and processes gripper commands, adding a sleep time between - consecutive gripper actions to prevent rapid toggling. - """ - - def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): - """ - Initialize the gripper action wrapper. - - Args: - env: The environment to wrap. - quantization_threshold: Threshold below which gripper commands are quantized to zero. - gripper_sleep: Minimum time in seconds between consecutive gripper commands. - """ - super().__init__(env) - self.quantization_threshold = quantization_threshold - self.gripper_sleep = gripper_sleep - self.last_gripper_action_time = 0.0 - self.last_gripper_action = None - - def action(self, action): - """ - Process gripper commands in the action. - - Args: - action: The original action from the agent. - - Returns: - Modified action with processed gripper command. - """ - if self.gripper_sleep > 0.0: - if ( - self.last_gripper_action is not None - and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep - ): - action[-1] = self.last_gripper_action - else: - self.last_gripper_action_time = time.perf_counter() - self.last_gripper_action = action[-1] - - gripper_command = action[-1] - # Gripper actions are between 0, 2 - # we want to quantize them to -1, 0 or 1 - gripper_command = gripper_command - 1.0 - - if self.quantization_threshold is not None: - # Quantize gripper command to -1, 0 or 1 - gripper_command = ( - np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 - ) - gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos - - gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] - - gripper_action_value = np.clip( - gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos - ) - action[-1] = gripper_action_value.item() - return action - - def reset(self, **kwargs): - """ - Reset the gripper action tracking. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - obs, info = super().reset(**kwargs) - self.last_gripper_action_time = 0.0 - self.last_gripper_action = None - return obs, info - - -class EEObservationWrapper(gym.ObservationWrapper): - """ - Wrapper that adds end-effector pose information to observations. - - This wrapper computes the end-effector pose using forward kinematics - and adds it to the observation space. - """ - - def __init__(self, env, ee_pose_limits): - """ - Initialize the end-effector observation wrapper. - - Args: - env: The environment to wrap. - ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. - """ - super().__init__(env) - - # Extend observation space to include end effector pose - prev_space = self.observation_space["observation.state"] - - self.observation_space["observation.state"] = gym.spaces.Box( - low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), - high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), - shape=(prev_space.shape[0] + 3,), - dtype=np.float32, - ) - - self.kinematics = RobotKinematics( - urdf_path=env.unwrapped.robot.config.urdf_path, - target_frame_name=env.unwrapped.robot.config.target_frame_name, - ) - - def observation(self, observation): - """ - Add end-effector pose to the observation. - - Args: - observation: Original observation from the environment. - - Returns: - Enhanced observation with end-effector pose information. - """ - current_joint_pos = self.unwrapped.current_observation["agent_pos"] - - current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos)[:3, 3] - observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1) - return observation - - -########################################################### -# Wrappers related to human intervention and input devices -########################################################### - - -class BaseLeaderControlWrapper(gym.Wrapper): - """ - Base class for leader-follower robot control wrappers. - - This wrapper enables human intervention through a leader-follower robot setup, - where the human can control a leader robot to guide the follower robot's movements. - """ - - def __init__( - self, - env, - teleop_device, - end_effector_step_sizes, - use_geared_leader_arm: bool = False, - use_gripper=False, - ): - """ - Initialize the base leader control wrapper. - - Args: - env: The environment to wrap. - teleop_device: The teleoperation device. - use_geared_leader_arm: Whether to use a geared leader arm setup. - use_gripper: Whether to include gripper control. - """ - super().__init__(env) - self.robot_leader = teleop_device - self.robot_follower = env.unwrapped.robot - self.use_geared_leader_arm = use_geared_leader_arm - self.use_gripper: bool = use_gripper - self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values())) - - # Set up keyboard event tracking - self._init_keyboard_events() - self.event_lock = Lock() # Thread-safe access to events - - # Initialize robot control - self.kinematics = RobotKinematics( - urdf_path=env.unwrapped.robot.config.urdf_path, - target_frame_name=env.unwrapped.robot.config.target_frame_name, - ) - self.leader_torque_enabled = True - self.prev_leader_gripper = None - - # Configure leader arm - # NOTE: Lower the gains of leader arm for automatic take-over - # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot - # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled - # Default value for P_coeff is 32 - self.robot_leader.bus.sync_write("Torque_Enable", 1) - for motor in self.robot_leader.bus.motors: - self.robot_leader.bus.write("P_Coefficient", motor, 16) - self.robot_leader.bus.write("I_Coefficient", motor, 0) - self.robot_leader.bus.write("D_Coefficient", motor, 16) - - self.leader_tracking_error_queue = deque(maxlen=4) - self._init_keyboard_listener() - - def _init_keyboard_events(self): - """ - Initialize the keyboard events dictionary. - - This method sets up tracking for keyboard events used for intervention control. - It should be overridden in subclasses to add additional events. - """ - self.keyboard_events = { - "episode_success": False, - "episode_end": False, - "rerecord_episode": False, - } - - def _handle_key_press(self, key, keyboard_device): - """ - Handle key press events. - - Args: - key: The key that was pressed. - keyboard: The keyboard module with key definitions. - - This method should be overridden in subclasses for additional key handling. - """ - try: - if key == keyboard_device.Key.esc: - self.keyboard_events["episode_end"] = True - return - if key == keyboard_device.Key.left: - self.keyboard_events["rerecord_episode"] = True - return - if hasattr(key, "char") and key.char == "s": - logging.info("Key 's' pressed. Episode success triggered.") - self.keyboard_events["episode_success"] = True - return - except Exception as e: - logging.error(f"Error handling key press: {e}") - - def _init_keyboard_listener(self): - """ - Initialize the keyboard listener for intervention control. - - This method sets up keyboard event handling if not in headless mode. - """ - from pynput import keyboard as keyboard_device - - def on_press(key): - with self.event_lock: - self._handle_key_press(key, keyboard_device) - - self.listener = keyboard_device.Listener(on_press=on_press) - self.listener.start() - - def _check_intervention(self): - """ - Check if human intervention is needed. - - Returns: - Boolean indicating whether intervention is needed. - - This method should be overridden in subclasses with specific intervention logic. - """ - return False - - def _handle_intervention(self, action): - """ - Process actions during intervention mode. - - Args: - action: The original action from the agent. - - Returns: - Tuple of (modified_action, intervention_action). - """ - if self.leader_torque_enabled: - self.robot_leader.bus.sync_write("Torque_Enable", 0) - self.leader_torque_enabled = False - - leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") - follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") - - leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict]) - follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict]) - - self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1])) - - # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation - leader_ee = self.kinematics.forward_kinematics(leader_pos)[:3, 3] - follower_ee = self.kinematics.forward_kinematics(follower_pos)[:3, 3] - - action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes) - # Normalize the action to the range [-1, 1] - action = action / self.end_effector_step_sizes - - if self.use_gripper: - if self.prev_leader_gripper is None: - self.prev_leader_gripper = np.clip( - leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos - ) - - # Get gripper action delta based on leader pose - leader_gripper = leader_pos[-1] - gripper_delta = leader_gripper - self.prev_leader_gripper - - # Normalize by max angle and quantize to {0,1,2} - normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos - if normalized_delta >= 0.3: - gripper_action = 2 - elif normalized_delta <= 0.1: - gripper_action = 0 - else: - gripper_action = 1 - - action = np.append(action, gripper_action) - - return action - - def _handle_leader_teleoperation(self): - """ - Handle leader teleoperation in non-intervention mode. - - This method synchronizes the leader robot position with the follower. - """ - - prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") - prev_leader_pos = np.array( - [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32 - ) - - if not self.leader_torque_enabled: - self.robot_leader.bus.sync_write("Torque_Enable", 1) - self.leader_torque_enabled = True - - follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") - follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) - - goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)} - self.robot_leader.bus.sync_write("Goal_Position", goal_pos) - - self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1])) - - def step(self, action): - """ - Execute a step with possible human intervention. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - is_intervention = self._check_intervention() - - # NOTE: - if is_intervention: - action = self._handle_intervention(action) - else: - self._handle_leader_teleoperation() - - # NOTE: - obs, reward, terminated, truncated, info = self.env.step(action) - - if isinstance(action, np.ndarray): - action = torch.from_numpy(action) - - # Add intervention info - info["is_intervention"] = is_intervention - info["action_intervention"] = action - - self.prev_leader_gripper = np.clip( - self.robot_leader.bus.sync_read("Present_Position")["gripper"], - 0, - self.robot_follower.config.max_gripper_pos, - ) - - # Check for success or manual termination - success = self.keyboard_events["episode_success"] - terminated = terminated or self.keyboard_events["episode_end"] or success - - if success: - reward = 1.0 - logging.info("Episode ended successfully with reward 1.0") - - return obs, reward, terminated, truncated, info - - def reset(self, **kwargs): - """ - Reset the environment and intervention state. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - self.keyboard_events = dict.fromkeys(self.keyboard_events, False) - self.leader_tracking_error_queue.clear() - return super().reset(**kwargs) - - def close(self): - """ - Clean up resources, including stopping keyboard listener. - - Returns: - Result of closing the wrapped environment. - """ - if hasattr(self, "listener") and self.listener is not None: - self.listener.stop() - return self.env.close() - - -class GearedLeaderControlWrapper(BaseLeaderControlWrapper): - """ - Wrapper that enables manual intervention via keyboard. - - This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling - of human intervention mode with keyboard controls. - """ - - def _init_keyboard_events(self): - """ - Initialize keyboard events including human intervention flag. - - Extends the base class dictionary with an additional flag for tracking - intervention state toggled by keyboard. - """ - super()._init_keyboard_events() - self.keyboard_events["human_intervention_step"] = False - - def _handle_key_press(self, key, keyboard_device): - """ - Handle key presses including space for intervention toggle. - - Args: - key: The key that was pressed. - keyboard: The keyboard module with key definitions. - - Extends the base handler to respond to space key for toggling intervention. - """ - super()._handle_key_press(key, keyboard_device) - if key == keyboard_device.Key.space: - if not self.keyboard_events["human_intervention_step"]: - logging.info( - "Space key pressed. Human intervention required.\n" - "Place the leader in similar pose to the follower and press space again." - ) - self.keyboard_events["human_intervention_step"] = True - log_say("Human intervention step.", play_sounds=True) - else: - self.keyboard_events["human_intervention_step"] = False - logging.info("Space key pressed for a second time.\nContinuing with policy actions.") - log_say("Continuing with policy actions.", play_sounds=True) - - def _check_intervention(self): - """ - Check if human intervention is active based on keyboard toggle. - - Returns: - Boolean indicating whether intervention mode is active. - """ - return self.keyboard_events["human_intervention_step"] - - -class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): - """ - Wrapper with automatic intervention based on error thresholds. - - This wrapper monitors the error between leader and follower positions - and automatically triggers intervention when error exceeds thresholds. - """ - - def __init__( - self, - env, - teleop_device, - end_effector_step_sizes, - use_gripper=False, - intervention_threshold=10.0, - release_threshold=1e-2, - ): - """ - Initialize the automatic intervention wrapper. - - Args: - env: The environment to wrap. - teleop_device: The teleoperation device. - use_gripper: Whether to include gripper control. - intervention_threshold: Error threshold to trigger intervention. - release_threshold: Error threshold to release intervention. - queue_size: Number of error measurements to track for smoothing. - """ - super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper) - - # Error tracking parameters - self.intervention_threshold = intervention_threshold # Threshold to trigger intervention - self.release_threshold = release_threshold # Threshold to release intervention - self.is_intervention_active = False - self.start_time = time.perf_counter() - - def _check_intervention(self): - """ - Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space. - - This method monitors the rate of change of leader-follower error in end_effector space - and automatically triggers intervention when the rate of change exceeds - the intervention threshold, releasing when it falls below the release threshold. - - Returns: - Boolean indicating whether intervention should be active. - """ - - # Condition for starting the intervention - # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over - if ( - not self.is_intervention_active - and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen - and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold - ): - self.is_intervention_active = True - self.leader_tracking_error_queue.clear() - log_say("Intervention started", play_sounds=True) - return True - - # Track the error over time in leader_tracking_error_queue - # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over - if ( - self.is_intervention_active - and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen - and np.var(self.leader_tracking_error_queue) < self.release_threshold - ): - self.is_intervention_active = False - self.leader_tracking_error_queue.clear() - log_say("Intervention ended", play_sounds=True) - return False - - # If not change has happened that merits a change in the intervention state, return the current state - return self.is_intervention_active - - def reset(self, **kwargs): - """ - Reset error tracking on environment reset. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - self.is_intervention_active = False - return super().reset(**kwargs) - - -class GamepadControlWrapper(gym.Wrapper): - """ - Wrapper that allows controlling a gym environment with a gamepad. - - This wrapper intercepts the step method and allows human input via gamepad - to override the agent's actions when desired. - """ - - def __init__( - self, - env, - teleop_device, # Accepts an instantiated teleoperator - use_gripper=False, # This should align with teleop_device's config - auto_reset=False, - ): - """ - Initialize the gamepad controller wrapper. - - Args: - env: The environment to wrap. - teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). - use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). - auto_reset: Whether to auto reset the environment when episode ends. - """ - super().__init__(env) - - self.teleop_device = teleop_device - # Ensure the teleop_device is connected if it has a connect method - if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected: - self.teleop_device.connect() - - # self.controller attribute is removed - - self.auto_reset = auto_reset - # use_gripper from args should ideally match teleop_device.config.use_gripper - # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config - self.use_gripper = use_gripper - - logging.info("Gamepad control wrapper initialized with provided teleop_device.") - print( - "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):" - ) - print(" Left analog stick: Move in X-Y plane") - print(" Right analog stick: Move in Z axis (up/down)") - print(" X/Square button: End episode (FAILURE)") - print(" Y/Triangle button: End episode (SUCCESS)") - print(" B/Circle button: Exit program") - - def get_teleop_commands( - self, - ) -> tuple[bool, np.ndarray, bool, bool, bool]: - """ - Get the current action from the gamepad if any input is active. - - Returns: - Tuple containing: - - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene()) - - action: The action derived from gamepad input (from teleop_device.get_action()) - - terminate_episode: Whether episode termination was requested - - success: Whether episode success was signaled - - rerecord_episode: Whether episode rerecording was requested - """ - if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None: - raise AttributeError( - "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper." - ) - - # Get status flags from the underlying gamepad controller within the teleop_device - self.teleop_device.gamepad.update() # Ensure gamepad state is fresh - intervention_is_active = self.teleop_device.gamepad.should_intervene() - episode_end_status = self.teleop_device.gamepad.get_episode_end_status() - - terminate_episode = episode_end_status is not None - success = episode_end_status == "success" - rerecord_episode = episode_end_status == "rerecord_episode" - - # Get the action dictionary from the teleop_device - action_dict = self.teleop_device.get_action() - - # Convert action_dict to numpy array based on expected structure - # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) - action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] - if self.use_gripper: - # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) - # This needs to be consistent with what EEActionWrapper expects if it's used downstream - # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) - # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. - gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present - action_list.append(float(gripper_val)) - - gamepad_action_np = np.array(action_list, dtype=np.float32) - - return ( - intervention_is_active, - gamepad_action_np, - terminate_episode, - success, - rerecord_episode, - ) - - def step(self, action): - """ - Step the environment, using gamepad input to override actions when active. - - Args: - action: Original action from agent. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - # Get gamepad state and action - ( - is_intervention, - gamepad_action, - terminate_episode, - success, - rerecord_episode, - ) = self.get_teleop_commands() - - # Update episode ending state if requested - if terminate_episode: - logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") - - # Only override the action if gamepad is active - action = gamepad_action if is_intervention else action - - # Step the environment - obs, reward, terminated, truncated, info = self.env.step(action) - - # Add episode ending if requested via gamepad - terminated = terminated or truncated or terminate_episode - - if success: - reward = 1.0 - logging.info("Episode ended successfully with reward 1.0") - - if isinstance(action, np.ndarray): - action = torch.from_numpy(action) - - info["is_intervention"] = is_intervention - # The original `BaseLeaderControlWrapper` puts `action_intervention` in info. - # For Gamepad, if intervention, `gamepad_action` is the intervention. - # If not intervention, policy's action is `action`. - # For consistency, let's store the *human's* action if intervention occurred. - info["action_intervention"] = action - - info["rerecord_episode"] = rerecord_episode - - # If episode ended, reset the state - if terminated or truncated: - # Add success/failure information to info dict - info["next.success"] = success - - # Auto reset if configured - if self.auto_reset: - obs, reset_info = self.reset() - info.update(reset_info) - - return obs, reward, terminated, truncated, info - - def close(self): - """ - Clean up resources when environment closes. - - Returns: - Result of closing the wrapped environment. - """ - if hasattr(self.teleop_device, "disconnect"): - self.teleop_device.disconnect() - - # Call the parent close method - return self.env.close() - - -class KeyboardControlWrapper(GamepadControlWrapper): - """ - Wrapper that allows controlling a gym environment with a keyboard. - - This wrapper intercepts the step method and allows human input via keyboard - to override the agent's actions when desired. - - Inherits from GamepadControlWrapper to avoid code duplication. - """ - - def __init__( - self, - env, - teleop_device, # Accepts an instantiated teleoperator - use_gripper=False, # This should align with teleop_device's config - auto_reset=False, - ): - """ - Initialize the gamepad controller wrapper. - - Args: - env: The environment to wrap. - teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). - use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). - auto_reset: Whether to auto reset the environment when episode ends. - """ - super().__init__(env, teleop_device, use_gripper, auto_reset) - - self.is_intervention_active = False - - logging.info("Keyboard control wrapper initialized with provided teleop_device.") - print("Keyboard controls:") - print(" Arrow keys: Move in X-Y plane") - print(" Shift and Shift_R: Move in Z axis") - print(" Right Ctrl and Left Ctrl: Open and close gripper") - print(" f: End episode with FAILURE") - print(" s: End episode with SUCCESS") - print(" r: End episode with RERECORD") - print(" i: Start/Stop Intervention") - - def get_teleop_commands( - self, - ) -> tuple[bool, np.ndarray, bool, bool, bool]: - action_dict = self.teleop_device.get_action() - episode_end_status = None - - # Unroll the misc_keys_queue to check for events related to intervention, episode success, etc. - while not self.teleop_device.misc_keys_queue.empty(): - key = self.teleop_device.misc_keys_queue.get() - if key == "i": - self.is_intervention_active = not self.is_intervention_active - elif key == "f": - episode_end_status = "failure" - elif key == "s": - episode_end_status = "success" - elif key == "r": - episode_end_status = "rerecord_episode" - - terminate_episode = episode_end_status is not None - success = episode_end_status == "success" - rerecord_episode = episode_end_status == "rerecord_episode" - - # Convert action_dict to numpy array based on expected structure - # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) - action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] - if self.use_gripper: - # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) - # This needs to be consistent with what EEActionWrapper expects if it's used downstream - # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) - # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. - gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present - action_list.append(float(gripper_val)) - - gamepad_action_np = np.array(action_list, dtype=np.float32) - - return ( - self.is_intervention_active, - gamepad_action_np, - terminate_episode, - success, - rerecord_episode, - ) - - -class GymHilDeviceWrapper(gym.Wrapper): - def __init__(self, env, device="cpu"): - super().__init__(env) - self.device = device - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - for k in obs: - obs[k] = obs[k].to(self.device) - if "action_intervention" in info: - # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device - info["action_intervention"] = info["action_intervention"].astype(np.float32) - info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) - return obs, reward, terminated, truncated, info - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - obs, info = self.env.reset(seed=seed, options=options) - for k in obs: - obs[k] = obs[k].to(self.device) - if "action_intervention" in info: - # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device - info["action_intervention"] = info["action_intervention"].astype(np.float32) - info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) - return obs, info - - -class GymHilObservationProcessorWrapper(gym.ObservationWrapper): - def __init__(self, env: gym.Env): - super().__init__(env) - prev_space = self.observation_space - new_space = {} - - for key in prev_space: - if "pixels" in key: - for k in prev_space["pixels"]: - new_space[f"observation.images.{k}"] = gym.spaces.Box( - 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8 - ) - - if key == "agent_pos": - new_space["observation.state"] = prev_space["agent_pos"] - - self.observation_space = gym.spaces.Dict(new_space) - - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: - return preprocess_observation(observation) - - -########################################################### -# Factory functions -########################################################### - - -def make_robot_env(cfg: EnvConfig) -> gym.Env: - """ - Factory function to create a robot environment. - - This function builds a robot environment with all necessary wrappers - based on the provided configuration. +def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]: + """Create robot environment from configuration. Args: - cfg: Configuration object containing environment parameters. + cfg: Environment configuration. Returns: - A gym environment with all necessary wrappers applied. + Tuple of (gym environment, teleoperator device). """ - if cfg.type == "hil": + # Check if this is a GymHIL simulation environment + if cfg.name == "gym_hil": + assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop" import gym_hil # noqa: F401 - # TODO (azouitine) + # Extract gripper settings with defaults + use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True + gripper_penalty = cfg.processor.gripper.gripper_penalty if cfg.processor.gripper is not None else 0.0 + env = gym.make( f"gym_hil/{cfg.task}", image_obs=True, render_mode="human", - use_gripper=cfg.wrapper.use_gripper, - gripper_penalty=cfg.wrapper.gripper_penalty, - ) - env = GymHilObservationProcessorWrapper(env=env) - env = GymHilDeviceWrapper(env=env, device=cfg.device) - env = BatchCompatibleWrapper(env=env) - env = TorchActionWrapper(env=env, device=cfg.device) - return env - - if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"): - raise ValueError( - "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop." + use_gripper=use_gripper, + gripper_penalty=gripper_penalty, ) - if cfg.robot is None: - raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.") + return env, None + + # Real robot environment + assert cfg.robot is not None, "Robot config must be provided for real robot environment" + assert cfg.teleop is not None, "Teleop config must be provided for real robot environment" + robot = make_robot_from_config(cfg.robot) teleop_device = make_teleoperator_from_config(cfg.teleop) teleop_device.connect() - # Create base environment + # Create base environment with safe defaults + use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True + display_cameras = ( + cfg.processor.observation.display_cameras if cfg.processor.observation is not None else False + ) + reset_pose = cfg.processor.reset.fixed_reset_joint_positions if cfg.processor.reset is not None else None + env = RobotEnv( robot=robot, - use_gripper=cfg.wrapper.use_gripper, - display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False, + use_gripper=use_gripper, + display_cameras=display_cameras, + reset_pose=reset_pose, ) - # Add observation and image processing - if cfg.wrapper: - if cfg.wrapper.add_joint_velocity_to_observation: - env = AddJointVelocityToObservation(env=env, fps=cfg.fps) - if cfg.wrapper.add_current_to_observation: - env = AddCurrentToObservation(env=env) - if cfg.wrapper.add_ee_pose_to_observation: - env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds) - - env = ConvertToLeRobotObservation(env=env, device=cfg.device) - - if cfg.wrapper and cfg.wrapper.crop_params_dict is not None: - env = ImageCropResizeWrapper( - env=env, - crop_params_dict=cfg.wrapper.crop_params_dict, - resize_size=cfg.wrapper.resize_size, - ) - - # Add reward computation and control wrappers - reward_classifier = init_reward_classifier(cfg) - if reward_classifier is not None: - env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) - - env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) - if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None: - env = GripperPenaltyWrapper( - env=env, - penalty=cfg.wrapper.gripper_penalty, - ) - - # Control mode specific wrappers - control_mode = cfg.wrapper.control_mode - if control_mode == "gamepad": - assert isinstance(teleop_device, GamepadTeleop), ( - "teleop_device must be an instance of GamepadTeleop for gamepad control mode" - ) - env = GamepadControlWrapper( - env=env, - teleop_device=teleop_device, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "keyboard_ee": - assert isinstance(teleop_device, KeyboardEndEffectorTeleop), ( - "teleop_device must be an instance of KeyboardEndEffectorTeleop for keyboard control mode" - ) - env = KeyboardControlWrapper( - env=env, - teleop_device=teleop_device, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "leader": - env = GearedLeaderControlWrapper( - env=env, - teleop_device=teleop_device, - end_effector_step_sizes=cfg.robot.end_effector_step_sizes, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "leader_automatic": - env = GearedLeaderAutomaticControlWrapper( - env=env, - teleop_device=teleop_device, - end_effector_step_sizes=cfg.robot.end_effector_step_sizes, - use_gripper=cfg.wrapper.use_gripper, - ) - else: - raise ValueError(f"Invalid control mode: {control_mode}") - - env = ResetWrapper( - env=env, - reset_pose=cfg.wrapper.fixed_reset_joint_positions, - reset_time_s=cfg.wrapper.reset_time_s, - ) - - env = BatchCompatibleWrapper(env=env) - env = TorchActionWrapper(env=env, device=cfg.device) - - return env + return env, teleop_device -def init_reward_classifier(cfg): - """ - Load a reward classifier policy from a pretrained path if configured. +def make_processors( + env: gym.Env, teleop_device: Teleoperator | None, cfg: HILSerlRobotEnvConfig, device: str = "cpu" +) -> tuple[ + DataProcessorPipeline[EnvTransition, EnvTransition], DataProcessorPipeline[EnvTransition, EnvTransition] +]: + """Create environment and action processors. Args: - cfg: The environment configuration containing classifier paths. + env: Robot environment instance. + teleop_device: Teleoperator device for intervention. + cfg: Processor configuration. + device: Target device for computations. Returns: - The loaded classifier model or None if not configured. + Tuple of (environment processor, action processor). """ - if cfg.reward_classifier_pretrained_path is None: - return None - - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - # Get device from config or default to CUDA - device = getattr(cfg, "device", "cpu") - - # Load the classifier directly using from_pretrained - classifier = Classifier.from_pretrained( - pretrained_name_or_path=cfg.reward_classifier_pretrained_path, + terminate_on_success = ( + cfg.processor.reset.terminate_on_success if cfg.processor.reset is not None else True ) - # Ensure model is on the correct device - classifier.to(device) - classifier.eval() # Set to evaluation mode + if cfg.name == "gym_hil": + action_pipeline_steps = [ + InterventionActionProcessorStep(terminate_on_success=terminate_on_success), + Torch2NumpyActionProcessorStep(), + ] - return classifier + env_pipeline_steps = [ + Numpy2TorchActionProcessorStep(), + VanillaObservationProcessorStep(), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=device), + ] + + return DataProcessorPipeline( + steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ), DataProcessorPipeline( + steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ) + + # Full processor pipeline for real robot environment + # Get robot and motor information for kinematics + motor_names = list(env.robot.bus.motors.keys()) + + # Set up kinematics solver if inverse kinematics is configured + kinematics_solver = None + if cfg.processor.inverse_kinematics is not None: + kinematics_solver = RobotKinematics( + urdf_path=cfg.processor.inverse_kinematics.urdf_path, + target_frame_name=cfg.processor.inverse_kinematics.target_frame_name, + joint_names=motor_names, + ) + + env_pipeline_steps = [VanillaObservationProcessorStep()] + + if cfg.processor.observation is not None: + if cfg.processor.observation.add_joint_velocity_to_observation: + env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps)) + if cfg.processor.observation.add_current_to_observation: + env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot)) + + if kinematics_solver is not None: + env_pipeline_steps.append( + ForwardKinematicsJointsToEEObservation( + kinematics=kinematics_solver, + motor_names=motor_names, + ) + ) + + if cfg.processor.image_preprocessing is not None: + env_pipeline_steps.append( + ImageCropResizeProcessorStep( + crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict, + resize_size=cfg.processor.image_preprocessing.resize_size, + ) + ) + + # Add time limit processor if reset config exists + if cfg.processor.reset is not None: + env_pipeline_steps.append( + TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps)) + ) + + # Add gripper penalty processor if gripper config exists and enabled + if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper: + env_pipeline_steps.append( + GripperPenaltyProcessorStep( + penalty=cfg.processor.gripper.gripper_penalty, + max_gripper_pos=cfg.processor.max_gripper_pos, + ) + ) + + if ( + cfg.processor.reward_classifier is not None + and cfg.processor.reward_classifier.pretrained_path is not None + ): + env_pipeline_steps.append( + RewardClassifierProcessorStep( + pretrained_path=cfg.processor.reward_classifier.pretrained_path, + device=device, + success_threshold=cfg.processor.reward_classifier.success_threshold, + success_reward=cfg.processor.reward_classifier.success_reward, + terminate_on_success=terminate_on_success, + ) + ) + + env_pipeline_steps.append(AddBatchDimensionProcessorStep()) + env_pipeline_steps.append(DeviceProcessorStep(device=device)) + + action_pipeline_steps = [ + AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device), + AddTeleopEventsAsInfoStep(teleop_device=teleop_device), + InterventionActionProcessorStep( + use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False, + terminate_on_success=terminate_on_success, + ), + ] + + # Replace InverseKinematicsProcessor with new kinematic processors + if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None: + # Add EE bounds and safety processor + inverse_kinematics_steps = [ + MapTensorToDeltaActionDictStep( + use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False + ), + MapDeltaActionToRobotActionStep(), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes, + motor_names=motor_names, + use_latched_reference=False, + use_ik_solution=True, + ), + EEBoundsAndSafety( + end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds, + ), + GripperVelocityToJoint( + clip_max=cfg.processor.max_gripper_pos, + speed_factor=1.0, + discrete_gripper=True, + ), + InverseKinematicsRLStep( + kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False + ), + ] + action_pipeline_steps.extend(inverse_kinematics_steps) + action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names)) + + return DataProcessorPipeline( + steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ), DataProcessorPipeline( + steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ) -########################################################### -# Record and replay functions -########################################################### - - -def record_dataset(env, policy, cfg): +def step_env_and_process_transition( + env: gym.Env, + transition: EnvTransition, + action: torch.Tensor, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], +) -> EnvTransition: """ - Record a dataset of robot interactions using either a policy or teleop. - - This function runs episodes in the environment and records the observations, - actions, and results for dataset creation. + Execute one step with processor pipeline. Args: - env: The environment to record from. - policy: Optional policy to generate actions (if None, uses teleop). - cfg: Configuration object containing recording parameters like: - - repo_id: Repository ID for dataset storage - - dataset_root: Local root directory for dataset - - num_episodes: Number of episodes to record - - fps: Frames per second for recording - - push_to_hub: Whether to push dataset to Hugging Face Hub - - task: Name/description of the task being recorded - - number_of_steps_after_success: Number of additional steps to continue recording after - a success (reward=1) is detected. This helps collect - more positive examples for reward classifier training. + env: The robot environment + transition: Current transition state + action: Action to execute + env_processor: Environment processor + action_processor: Action processor + + Returns: + Processed transition with updated state. """ - from lerobot.datasets.lerobot_dataset import LeRobotDataset - # Setup initial action (zero action if using teleop) - action = env.action_space.sample() * 0.0 - - action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] - if cfg.wrapper.use_gripper: - action_names.append("gripper_delta") - - # Configure dataset features based on environment spaces - features = { - "observation.state": { - "dtype": "float32", - "shape": env.observation_space["observation.state"].shape, - "names": None, - }, - "action": { - "dtype": "float32", - "shape": (len(action_names),), - "names": action_names, - }, - "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, - "next.done": {"dtype": "bool", "shape": (1,), "names": None}, - "complementary_info.discrete_penalty": { - "dtype": "float32", - "shape": (1,), - "names": ["discrete_penalty"], - }, - } - - # Add image features - for key in env.observation_space: - if "image" in key: - features[key] = { - "dtype": "video", - "shape": env.observation_space[key].shape, - "names": ["channels", "height", "width"], - } - - # Create dataset - dataset = LeRobotDataset.create( - cfg.repo_id, - cfg.fps, - root=cfg.dataset_root, - use_videos=True, - image_writer_threads=4, - image_writer_processes=0, - features=features, + # Create action transition + transition[TransitionKey.ACTION] = action + transition[TransitionKey.OBSERVATION] = ( + env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {} ) + processed_action_transition = action_processor(transition) + processed_action = processed_action_transition[TransitionKey.ACTION] - # Record episodes - episode_index = 0 - recorded_action = None - while episode_index < cfg.num_episodes: - obs, _ = env.reset() - start_episode_t = time.perf_counter() - log_say(f"Recording episode {episode_index}", play_sounds=True) + obs, reward, terminated, truncated, info = env.step(processed_action) - # Track success state collection - success_detected = False - success_steps_collected = 0 + reward = reward + processed_action_transition[TransitionKey.REWARD] + terminated = terminated or processed_action_transition[TransitionKey.DONE] + truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] + complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() + new_info = processed_action_transition[TransitionKey.INFO].copy() + new_info.update(info) - # Run episode steps - while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: - start_loop_t = time.perf_counter() + new_transition = create_transition( + observation=obs, + action=processed_action, + reward=reward, + done=terminated, + truncated=truncated, + info=new_info, + complementary_data=complementary_data, + ) + new_transition = env_processor(new_transition) - # Get action from policy if available - if cfg.pretrained_policy_name_or_path is not None: - action = policy.select_action(obs) + return new_transition - # Step environment - obs, reward, terminated, truncated, info = env.step(action) - # Check if episode needs to be rerecorded - if info.get("rerecord_episode", False): - break +def control_loop( + env: gym.Env, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + teleop_device: Teleoperator, + cfg: GymManipulatorConfig, +) -> None: + """Main control loop for robot environment interaction. + if cfg.mode == "record": then a dataset will be created and recorded - # For teleop, get action from intervention - recorded_action = { - "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + Args: + env: The robot environment + env_processor: Environment processor + action_processor: Action processor + teleop_device: Teleoperator device + cfg: gym_manipulator configuration + """ + dt = 1.0 / cfg.env.fps + + print(f"Starting control loop at {cfg.env.fps} FPS") + print("Controls:") + print("- Use gamepad/teleop device for intervention") + print("- When not intervening, robot will stay still") + print("- Press Ctrl+C to exit") + + # Reset environment and processors + obs, info = env.reset() + complementary_data = ( + {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {} + ) + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + transition = env_processor(data=transition) + + # Determine if gripper is used + use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True + + dataset = None + if cfg.mode == "record": + action_features = teleop_device.action_features + features = { + "action": action_features, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + } + if use_gripper: + features["complementary_info.discrete_penalty"] = { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], } - # Process observation for dataset - obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + for key, value in transition[TransitionKey.OBSERVATION].items(): + if key == "observation.state": + features[key] = { + "dtype": "float32", + "shape": value.squeeze(0).shape, + "names": None, + } + if "image" in key: + features[key] = { + "dtype": "video", + "shape": value.squeeze(0).shape, + "names": ["channels", "height", "width"], + } - # Check if we've just detected success - if reward == 1.0 and not success_detected: - success_detected = True - logging.info("Success detected! Collecting additional success states.") + # Create dataset + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.env.fps, + root=cfg.dataset.root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) - # Add frame to dataset - continue marking as success even during extra collection steps - frame = {**obs_processed, **recorded_action} + episode_idx = 0 + episode_step = 0 + episode_start_time = time.perf_counter() - # If we're in the success collection phase, keep marking rewards as 1.0 - if success_detected: - frame["next.reward"] = np.array([1.0], dtype=np.float32) - else: - frame["next.reward"] = np.array([reward], dtype=np.float32) + while episode_idx < cfg.dataset.num_episodes_to_record: + step_start_time = time.perf_counter() - # Only mark as done if we're truly done (reached end or collected enough success states) - really_done = terminated or truncated - if success_detected: - success_steps_collected += 1 - really_done = success_steps_collected >= cfg.number_of_steps_after_success + # Create a neutral action (no movement) + neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) + if use_gripper: + neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay - frame["next.done"] = np.array([really_done], dtype=bool) - frame["complementary_info.discrete_penalty"] = torch.tensor( - [info.get("discrete_penalty", 0.0)], dtype=torch.float32 + # Use the new step function + transition = step_env_and_process_transition( + env=env, + transition=transition, + action=neutral_action, + env_processor=env_processor, + action_processor=action_processor, + ) + terminated = transition.get(TransitionKey.DONE, False) + truncated = transition.get(TransitionKey.TRUNCATED, False) + + if cfg.mode == "record": + observations = { + k: v.squeeze(0).cpu() + for k, v in transition[TransitionKey.OBSERVATION].items() + if isinstance(v, torch.Tensor) + } + # Use teleop_action if available, otherwise use the action from the transition + action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "teleop_action", transition[TransitionKey.ACTION] ) - frame["task"] = cfg.task - dataset.add_frame(frame) + frame = { + **observations, + "action": action_to_record.cpu(), + "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + "next.done": np.array([terminated or truncated], dtype=bool), + } + if use_gripper: + discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) + frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32) - # Maintain consistent timing - if cfg.fps: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / cfg.fps - dt_s) + if dataset is not None: + frame["task"] = cfg.dataset.task + dataset.add_frame(frame) - # Check if we should end the episode - if (terminated or truncated) and not success_detected: - # Regular termination without success - break - elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success: - # We've collected enough success states - logging.info(f"Collected {success_steps_collected} additional success states") - break + episode_step += 1 - # Handle episode recording - if info.get("rerecord_episode", False): - dataset.clear_episode_buffer() - logging.info(f"Re-recording episode {episode_index}") - continue + # Handle episode termination + if terminated or truncated: + episode_time = time.perf_counter() - episode_start_time + logging.info( + f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + ) + episode_step = 0 + episode_idx += 1 - dataset.save_episode() - episode_index += 1 + if dataset is not None: + if transition[TransitionKey.INFO].get("rerecord_episode", False): + logging.info(f"Re-recording episode {episode_idx}") + dataset.clear_episode_buffer() + episode_idx -= 1 + else: + logging.info(f"Saving episode {episode_idx}") + dataset.save_episode() - # Finalize dataset - # dataset.consolidate(run_compute_stats=True) - if cfg.push_to_hub: + # Reset for new episode + obs, info = env.reset() + env_processor.reset() + action_processor.reset() + + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) + + # Maintain fps timing + busy_wait(dt - (time.perf_counter() - step_start_time)) + + if dataset is not None and cfg.dataset.push_to_hub: + logging.info("Pushing dataset to hub") dataset.push_to_hub() -def replay_episode(env, cfg): - """ - Replay a recorded episode in the environment. +def replay_trajectory( + env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig +) -> None: + """Replay recorded trajectory on robot environment.""" + assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay" - This function loads actions from a previously recorded episode - and executes them in the environment. + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=[cfg.dataset.replay_episode], + download_videos=False, + ) + episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) + actions = episode_frames.select_columns("action") - Args: - env: The environment to replay in. - cfg: Configuration object containing replay parameters: - - repo_id: Repository ID for dataset - - dataset_root: Local root directory for dataset - - episode: Episode ID to replay - """ - from lerobot.datasets.lerobot_dataset import LeRobotDataset + _, info = env.reset() - dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) - env.reset() - - actions = dataset.hf_dataset.select_columns("action") - - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() - - action = actions[idx]["action"] - env.step(action) - - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / 10 - dt_s) + for action_data in actions: + start_time = time.perf_counter() + transition = create_transition( + observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}, + action=action_data["action"], + ) + transition = action_processor(transition) + env.step(transition[TransitionKey.ACTION]) + busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time)) @parser.wrap() -def main(cfg: EnvConfig): - """Main entry point for the robot environment script. +def main(cfg: GymManipulatorConfig) -> None: + """Main entry point for gym manipulator script.""" + env, teleop_device = make_robot_env(cfg.env) + env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device) - This function runs the robot environment in one of several modes - based on the provided configuration. - - Args: - cfg: Configuration object defining the run parameters, - including mode (record, replay, random) and other settings. - """ - env = make_robot_env(cfg) - - if cfg.mode == "record": - policy = None - if cfg.pretrained_policy_name_or_path is not None: - from lerobot.policies.sac.modeling_sac import SACPolicy - - policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) - policy.to(cfg.device) - policy.eval() - - record_dataset( - env, - policy=policy, - cfg=cfg, - ) - exit() + print("Environment observation space:", env.observation_space) + print("Environment action space:", env.action_space) + print("Environment processor:", env_processor) + print("Action processor:", action_processor) if cfg.mode == "replay": - replay_episode( - env, - cfg=cfg, - ) + replay_trajectory(env, action_processor, cfg) exit() - env.reset() - - # Initialize the smoothed action as a random sample. - smoothed_action = env.action_space.sample() * 0.0 - - # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. - # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. - alpha = 1.0 - - num_episode = 0 - successes = [] - while num_episode < 10: - start_loop_s = time.perf_counter() - # Sample a new random action from the robot's action space. - new_random_action = env.action_space.sample() - # Update the smoothed action using an exponential moving average. - smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action - - # Execute the step: wrap the NumPy action in a torch tensor. - obs, reward, terminated, truncated, info = env.step(smoothed_action) - if terminated or truncated: - successes.append(reward) - env.reset() - num_episode += 1 - - dt_s = time.perf_counter() - start_loop_s - busy_wait(1 / cfg.fps - dt_s) - - logging.info(f"Success after 20 steps {successes}") - logging.info(f"success rate {sum(successes) / len(successes)}") + control_loop(env, env_processor, action_processor, teleop_device, cfg) if __name__ == "__main__": diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index f9f3901c..5d995382 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so100_follower # noqa: F401 from lerobot.scripts.rl import learner_service from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( MAX_MESSAGE_SIZE, @@ -102,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger LOG_PREFIX = "[LEARNER]" -################################################# -# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # -################################################# - - @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): if not use_threads(cfg): @@ -249,9 +245,7 @@ def start_learner_threads( logging.info("[LEARNER] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def add_actor_information_and_train( @@ -819,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M return optimizers, lr_scheduler -################################################# -# Training setup functions # -################################################# +# Training setup functions def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: @@ -1022,9 +1014,7 @@ def initialize_offline_replay_buffer( return offline_replay_buffer -################################################# -# Utilities/Helpers functions # -################################################# +# Utilities/Helpers functions def get_observation_features( @@ -1048,10 +1038,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True) - next_observation_features = policy.actor.encoder.get_cached_image_features( - next_observations, normalize=True - ) + observation_features = policy.actor.encoder.get_cached_image_features(observations) + next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations) return observation_features, next_observation_features @@ -1176,7 +1164,7 @@ def process_transitions( # Add to offline buffer if it's an intervention if dataset_repo_id is not None and transition.get("complementary_info", {}).get( - "is_intervention" + TeleopEvents.IS_INTERVENTION ): offline_replay_buffer.add(**transition) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 398bea90..485fc927 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.scripts.eval import eval_policy @@ -64,6 +64,28 @@ def update_policy( use_amp: bool = False, lock=None, ) -> tuple[MetricsTracker, dict]: + """ + Performs a single training step to update the policy's weights. + + This function executes the forward and backward passes, clips gradients, and steps the optimizer and + learning rate scheduler. It also handles mixed-precision training via a GradScaler. + + Args: + train_metrics: A MetricsTracker instance to record training statistics. + policy: The policy model to be trained. + batch: A batch of training data. + optimizer: The optimizer used to update the policy's parameters. + grad_clip_norm: The maximum norm for gradient clipping. + grad_scaler: The GradScaler for automatic mixed-precision training. + lr_scheduler: An optional learning rate scheduler. + use_amp: A boolean indicating whether to use automatic mixed precision. + lock: An optional lock for thread-safe optimizer updates. + + Returns: + A tuple containing: + - The updated MetricsTracker with new statistics for this step. + - A dictionary of outputs from the policy's forward pass, for logging purposes. + """ start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() @@ -107,6 +129,20 @@ def update_policy( @parser.wrap() def train(cfg: TrainPipelineConfig): + """ + Main function to train a policy. + + This function orchestrates the entire training pipeline, including: + - Setting up logging, seeding, and device configuration. + - Creating the dataset, evaluation environment (if applicable), policy, and optimizer. + - Handling resumption from a checkpoint. + - Running the main training loop, which involves fetching data batches and calling `update_policy`. + - Periodically logging metrics, saving model checkpoints, and evaluating the policy. + - Pushing the final trained model to the Hugging Face Hub if configured. + + Args: + cfg: A `TrainPipelineConfig` object containing all training configurations. + """ cfg.validate() logging.info(pformat(cfg.to_dict())) @@ -141,6 +177,16 @@ def train(cfg: TrainPipelineConfig): ds_meta=dataset.meta, ) + # Create processors - only provide dataset_stats if not resuming from saved processors + processor_kwargs = {} + if not (cfg.resume and cfg.policy.pretrained_path): + # Only provide dataset_stats when not resuming from saved processor state + processor_kwargs["dataset_stats"] = dataset.meta.stats + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs + ) + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) @@ -205,15 +251,9 @@ def train(cfg: TrainPipelineConfig): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time - for key in batch: - if isinstance(batch[key], torch.Tensor): - if batch[key].dtype != torch.bool: - batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key] - - batch[key] = batch[key].to(device, non_blocking=device.type == "cuda") - train_tracker, output_dict = update_policy( train_tracker, policy, @@ -245,7 +285,9 @@ def train(cfg: TrainPipelineConfig): if cfg.save_checkpoint and is_saving_step: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) + save_checkpoint( + checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor + ) update_last_checkpoint(checkpoint_dir) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) @@ -258,9 +300,11 @@ def train(cfg: TrainPipelineConfig): torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, + env=eval_env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", max_episodes_rendered=4, start_seed=cfg.seed, @@ -289,6 +333,8 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) def main(): diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index e7be6967..62c243e9 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -56,11 +56,17 @@ import time from dataclasses import asdict, dataclass from pprint import pformat -import draccus import rerun as rr from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_processors, +) from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -100,36 +106,81 @@ class TeleoperateConfig: def teleop_loop( - teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None + teleop: Teleoperator, + robot: Robot, + fps: int, + teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction], + robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction], + robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], + display_data: bool = False, + duration: float | None = None, ): + """ + This function continuously reads actions from a teleoperation device, processes them through optional + pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a + specified frequency until a set duration is reached or it is manually interrupted. + + Args: + teleop: The teleoperator device instance providing control actions. + robot: The robot instance being controlled. + fps: The target frequency for the control loop in frames per second. + display_data: If True, fetches robot observations and displays them in the console and Rerun. + duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. + teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. + robot_action_processor: An optional pipeline to process actions before they are sent to the robot. + robot_observation_processor: An optional pipeline to process raw observations from the robot. + """ + display_len = max(len(key) for key in robot.action_features) start = time.perf_counter() + while True: loop_start = time.perf_counter() - action = teleop.get_action() - if display_data: - observation = robot.get_observation() - log_rerun_data(observation, action) - robot.send_action(action) + # Get robot observation + # Not really needed for now other than for visualization + # teleop_action_processor can take None as an observation + # given that it is the identity processor as default + obs = robot.get_observation() + + # Get teleop action + raw_action = teleop.get_action() + + # Process teleop action through pipeline + teleop_action = teleop_action_processor((raw_action, obs)) + + # Process action for robot through pipeline + robot_action_to_send = robot_action_processor((teleop_action, obs)) + + # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) + _ = robot.send_action(robot_action_to_send) + + if display_data: + # Process robot observation through pipeline + obs_transition = robot_observation_processor(obs) + + log_rerun_data( + observation=obs_transition, + action=teleop_action, + ) + + print("\n" + "-" * (display_len + 10)) + print(f"{'NAME':<{display_len}} | {'NORM':>7}") + # Display the final robot action that was sent + for motor, value in robot_action_to_send.items(): + print(f"{motor:<{display_len}} | {value:>7.2f}") + move_cursor_up(len(robot_action_to_send) + 5) + dt_s = time.perf_counter() - loop_start busy_wait(1 / fps - dt_s) - loop_s = time.perf_counter() - loop_start - - print("\n" + "-" * (display_len + 10)) - print(f"{'NAME':<{display_len}} | {'NORM':>7}") - for motor, value in action.items(): - print(f"{motor:<{display_len}} | {value:>7.2f}") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") if duration is not None and time.perf_counter() - start >= duration: return - move_cursor_up(len(action) + 5) - -@draccus.wrap() +@parser.wrap() def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) @@ -138,12 +189,22 @@ def teleoperate(cfg: TeleoperateConfig): teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) + teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() teleop.connect() robot.connect() try: - teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s) + teleop_loop( + teleop=teleop, + robot=robot, + fps=cfg.fps, + display_data=cfg.display_data, + duration=cfg.teleop_time_s, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) except KeyboardInterrupt: pass finally: diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py index 56f48af7..ee508ddd 100644 --- a/src/lerobot/teleoperators/__init__.py +++ b/src/lerobot/teleoperators/__init__.py @@ -16,4 +16,4 @@ from .config import TeleoperatorConfig from .teleoperator import Teleoperator -from .utils import make_teleoperator_from_config +from .utils import TeleopEvents, make_teleoperator_from_config diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py index 7ebed6b3..d994dadd 100644 --- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py +++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py @@ -16,6 +16,8 @@ import logging +from ..utils import TeleopEvents + class InputController: """Base class for input controllers that generate motion deltas.""" @@ -134,10 +136,10 @@ class KeyboardController(InputController): return False elif key == keyboard.Key.enter: self.key_states["success"] = True - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS elif key == keyboard.Key.backspace: self.key_states["failure"] = True - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE except AttributeError: pass @@ -255,13 +257,13 @@ class GamepadController(InputController): for event in pygame.event.get(): if event.type == pygame.JOYBUTTONDOWN: if event.button == 3: - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS # A button (1) for failure elif event.button == 1: - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE # X button (0) for rerecord elif event.button == 0: - self.episode_end_status = "rerecord_episode" + self.episode_end_status = TeleopEvents.RERECORD_EPISODE # RB button (6) for closing gripper elif event.button == 6: @@ -295,8 +297,8 @@ class GamepadController(InputController): try: # Read joystick axes # Left stick X and Y (typically axes 0 and 1) - x_input = self.joystick.get_axis(0) # Left/Right - y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + y_input = self.joystick.get_axis(0) # Up/Down (often inverted) + x_input = self.joystick.get_axis(1) # Left/Right # Right stick Y (typically axis 3 or 4) z_input = self.joystick.get_axis(3) # Up/Down for Z @@ -308,7 +310,7 @@ class GamepadController(InputController): # Calculate deltas (note: may need to invert axes depending on controller) delta_x = -x_input * self.x_step_size # Forward/backward - delta_y = y_input * self.y_step_size # Left/right + delta_y = -y_input * self.y_step_size # Left/right delta_z = -z_input * self.z_step_size # Up/down return delta_x, delta_y, delta_z @@ -451,11 +453,11 @@ class GamepadControllerHID(InputController): # Check if X/Square button (bit 5) is pressed for failure # Check if A/Cross button (bit 4) is pressed for rerecording if buttons & 1 << 7: - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS elif buttons & 1 << 5: - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE elif buttons & 1 << 4: - self.episode_end_status = "rerecord_episode" + self.episode_end_status = TeleopEvents.RERECORD_EPISODE else: self.episode_end_status = None diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 98a0647e..c7072f4a 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -21,6 +21,7 @@ from typing import Any import numpy as np from ..teleoperator import Teleoperator +from ..utils import TeleopEvents from .configuration_gamepad import GamepadTeleopConfig @@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator): return action_dict + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the gamepad such as intervention status, + episode termination, success indicators, etc. + + Returns: + Dictionary containing: + - is_intervention: bool - Whether human is currently intervening + - terminate_episode: bool - Whether to terminate the current episode + - success: bool - Whether the episode was successful + - rerecord_episode: bool - Whether to rerecord the episode + """ + if self.gamepad is None: + return { + TeleopEvents.IS_INTERVENTION: False, + TeleopEvents.TERMINATE_EPISODE: False, + TeleopEvents.SUCCESS: False, + TeleopEvents.RERECORD_EPISODE: False, + } + + # Update gamepad state to get fresh inputs + self.gamepad.update() + + # Check if intervention is active + is_intervention = self.gamepad.should_intervene() + + # Get episode end status + episode_end_status = self.gamepad.get_episode_end_status() + terminate_episode = episode_end_status in [ + TeleopEvents.RERECORD_EPISODE, + TeleopEvents.FAILURE, + ] + success = episode_end_status == TeleopEvents.SUCCESS + rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE + + return { + TeleopEvents.IS_INTERVENTION: is_intervention, + TeleopEvents.TERMINATE_EPISODE: terminate_episode, + TeleopEvents.SUCCESS: success, + TeleopEvents.RERECORD_EPISODE: rerecord_episode, + } + def disconnect(self) -> None: """Disconnect from the gamepad.""" if self.gamepad is not None: diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index d034982f..7f489b25 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -24,6 +24,7 @@ from typing import Any from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator +from ..utils import TeleopEvents from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig PYNPUT_AVAILABLE = True @@ -176,16 +177,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } - def _on_press(self, key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, True)) - - def _on_release(self, key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, False)) - def get_action(self) -> dict[str, Any]: if not self.is_connected: raise DeviceNotConnectedError( @@ -235,3 +226,66 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): action_dict["gripper"] = gripper_action return action_dict + + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the keyboard such as intervention status, + episode termination, success indicators, etc. + + Keyboard mappings: + - Any movement keys pressed = intervention active + - 's' key = success (terminate episode successfully) + - 'r' key = rerecord episode (terminate and rerecord) + - 'q' key = quit episode (terminate without success) + + Returns: + Dictionary containing: + - is_intervention: bool - Whether human is currently intervening + - terminate_episode: bool - Whether to terminate the current episode + - success: bool - Whether the episode was successful + - rerecord_episode: bool - Whether to rerecord the episode + """ + if not self.is_connected: + return { + TeleopEvents.IS_INTERVENTION: False, + TeleopEvents.TERMINATE_EPISODE: False, + TeleopEvents.SUCCESS: False, + TeleopEvents.RERECORD_EPISODE: False, + } + + # Check if any movement keys are currently pressed (indicates intervention) + movement_keys = [ + keyboard.Key.up, + keyboard.Key.down, + keyboard.Key.left, + keyboard.Key.right, + keyboard.Key.shift, + keyboard.Key.shift_r, + keyboard.Key.ctrl_r, + keyboard.Key.ctrl_l, + ] + is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys) + + # Check for episode control commands from misc_keys_queue + terminate_episode = False + success = False + rerecord_episode = False + + # Process any pending misc keys + while not self.misc_keys_queue.empty(): + key = self.misc_keys_queue.get_nowait() + if key == "s": + success = True + elif key == "r": + terminate_episode = True + rerecord_episode = True + elif key == "q": + terminate_episode = True + success = False + + return { + TeleopEvents.IS_INTERVENTION: is_intervention, + TeleopEvents.TERMINATE_EPISODE: terminate_episode, + TeleopEvents.SUCCESS: success, + TeleopEvents.RERECORD_EPISODE: rerecord_episode, + } diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py new file mode 100644 index 00000000..2b28c1f9 --- /dev/null +++ b/src/lerobot/teleoperators/phone/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_phone import PhoneConfig +from .teleop_phone import Phone diff --git a/src/lerobot/teleoperators/phone/config_phone.py b/src/lerobot/teleoperators/phone/config_phone.py new file mode 100644 index 00000000..380d5f5f --- /dev/null +++ b/src/lerobot/teleoperators/phone/config_phone.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +from ..config import TeleoperatorConfig + + +class PhoneOS(Enum): + ANDROID = "android" + IOS = "ios" + + +@TeleoperatorConfig.register_subclass("phone") +@dataclass +class PhoneConfig(TeleoperatorConfig): + phone_os: PhoneOS = PhoneOS.IOS + camera_offset = np.array( + [0.0, -0.02, 0.04] + ) # iPhone 14 Pro camera is 2cm off center and 4cm above center diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py new file mode 100644 index 00000000..67e64c7d --- /dev/null +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -0,0 +1,110 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep +from lerobot.teleoperators.phone.config_phone import PhoneOS + + +@ProcessorStepRegistry.register("map_phone_action_to_robot_action") +@dataclass +class MapPhoneActionToRobotAction(RobotActionProcessorStep): + """ + Maps calibrated phone pose actions to standardized robot action inputs. + + This processor step acts as a bridge between the phone teleoperator's output + and the robot's expected action format. It remaps the phone's 6-DoF pose + (position and rotation) to the robot's target end-effector pose, applying + necessary axis inversions and swaps. It also interprets platform-specific + button presses to generate a gripper command. + + Attributes: + platform: The operating system of the phone (iOS or Android), used + to determine the correct button mappings for the gripper. + """ + + # TODO(Steven): Gripper vel could be output of phone_teleop directly + platform: PhoneOS + _enabled_prev: bool = field(default=False, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + """ + Processes the phone action dictionary to create a robot action dictionary. + + Args: + act: The input action dictionary from the phone teleoperator. + + Returns: + A new action dictionary formatted for the robot controller. + + Raises: + ValueError: If 'pos' or 'rot' keys are missing from the input action. + """ + # Pop them from the action + enabled = bool(action.pop("phone.enabled")) + pos = action.pop("phone.pos") + rot = action.pop("phone.rot") + inputs = action.pop("phone.raw_inputs") + + if pos is None or rot is None: + raise ValueError("pos and rot must be present in action") + + rotvec = rot.as_rotvec() # Absolute orientation as rotvec + + # Map certain inputs to certain actions + if self.platform == PhoneOS.IOS: + gripper_vel = float(inputs.get("a3", 0.0)) + else: + a = float(inputs.get("reservedButtonA", 0.0)) + b = float(inputs.get("reservedButtonB", 0.0)) + gripper_vel = ( + a - b + ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed + + # For some actions we need to invert the axis + action["enabled"] = enabled + action["target_x"] = -pos[1] if enabled else 0.0 + action["target_y"] = pos[0] if enabled else 0.0 + action["target_z"] = pos[2] if enabled else 0.0 + action["target_wx"] = rotvec[1] if enabled else 0.0 + action["target_wy"] = rotvec[0] if enabled else 0.0 + action["target_wz"] = -rotvec[2] if enabled else 0.0 + action["gripper_vel"] = gripper_vel # Still send gripper action when disabled + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["enabled", "pos", "rot", "raw_inputs"]: + features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None) + + for feat in [ + "enabled", + "target_x", + "target_y", + "target_z", + "target_wx", + "target_wy", + "target_wz", + "gripper_vel", + ]: + features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py new file mode 100644 index 00000000..c90729ef --- /dev/null +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Docs: +# hebi: https://docs.hebi.us/tools.html#mobile-io +# teleop: https://github.com/SpesRobotics/teleop + +import logging +import threading +import time + +import hebi +import numpy as np +from teleop import Teleop + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.utils.rotation import Rotation + +logger = logging.getLogger(__name__) + + +class BasePhone: + _enabled: bool = False + _calib_pos: np.ndarray | None = None + _calib_rot_inv: Rotation | None = None + + def _reapply_position_calibration(self, pos: np.ndarray) -> None: + self._calib_pos = pos.copy() + + @property + def is_calibrated(self) -> bool: + return (self._calib_pos is not None) and (self._calib_rot_inv is not None) + + @property + def action_features(self) -> dict[str, type]: + return { + "phone.pos": np.ndarray, # shape (3,) + "phone.rot": Rotation, # scipy.spatial.transform.Rotation + "phone.raw_inputs": dict, # analogs/buttons or webXR meta + "phone.enabled": bool, + } + + @property + def feedback_features(self) -> dict[str, type]: + # No haptic or other feedback implemented yet + pass + + def configure(self) -> None: + # No additional configuration required for phone teleop + pass + + def send_feedback(self, feedback: dict[str, float]) -> None: + # We could add haptic feedback (vibrations) here, but it's not implemented yet + raise NotImplementedError + + +class IOSPhone(BasePhone, Teleoperator): + name = "ios_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._group = None + + @property + def is_connected(self) -> bool: + return self._group is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") + lookup = hebi.Lookup() + time.sleep(2.0) + group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) + if group is None: + raise RuntimeError("Mobile I/O not found — check name/family settings in the app.") + self._group = group + logger.info(f"{self} connected to HEBI group with {group.size} module(s).") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n") + position, rotation = self._wait_for_capture_trigger() + self._calib_pos = position.copy() + self._calib_rot_inv = rotation.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """ + Blocks execution until the calibration trigger is detected from the iOS device. + + This method enters a loop, continuously reading the phone's state. It waits for the user to press + and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and + returns the phone's pose at that exact moment. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ + while True: + has_pose, position, rotation, fb_pose = self._read_current_pose() + if not has_pose: + time.sleep(0.01) + continue + + io = getattr(fb_pose, "io", None) + button_b = getattr(io, "b", None) if io is not None else None + button_b1_pressed = False + if button_b is not None: + button_b1_pressed = bool(button_b.get_int(1)) + if button_b1_pressed: + return position, rotation + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK. + + This method fetches the latest feedback packet from the HEBI group, extracts the ARKit + position and orientation, and converts them into a standard format. It also applies a + configured camera offset to adjust the pose from the camera's frame to the phone's + physical frame. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was successfully read. + - The 3D position as a NumPy array, or None if not available. + - The orientation as a `Rotation` object, or None if not available. + - The raw HEBI feedback object for accessing other data like button presses. + """ + fbk = self._group.get_next_feedback() + pose = fbk[0] + ar_pos = getattr(pose, "ar_position", None) + ar_quat = getattr(pose, "ar_orientation", None) + if ar_pos is None or ar_quat is None: + return False, None, None, None + # HEBI provides orientation in w, x, y, z format. + # Scipy's Rotation expects x, y, z, w. + quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw + rot = Rotation.from_quat(quat_xyzw) + pos = ar_pos - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def get_action(self) -> dict: + has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose() + if not has_pose or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + io = getattr(fb_pose, "io", None) + if io is not None: + bank_a, bank_b = io.a, io.b + if bank_a: + for ch in range(1, 9): + if bank_a.has_float(ch): + raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) + if bank_b: + for ch in range(1, 9): + if bank_b.has_int(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) + elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) + + enable = bool(raw_inputs.get("b1", 0)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_position) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rotation + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._group = None + + +class AndroidPhone(BasePhone, Teleoperator): + name = "android_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._teleop = None + self._teleop_thread = None + self._latest_pose = None + self._latest_message = None + self._android_lock = threading.Lock() + + @property + def is_connected(self) -> bool: + return self._teleop is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Starting teleop stream for Android...") + self._teleop = Teleop() + self._teleop.subscribe(self._android_callback) + self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) + self._teleop_thread.start() + logger.info(f"{self} connected, teleop stream started.") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Touch and move on the WebXR page to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """ + Blocks execution until the calibration trigger is detected from the Android device. + + This method enters a loop, continuously checking the latest message received from the WebXR + session. It waits for the user to touch and move their finger on the screen, which generates + a `move` event. Once this event is detected, the loop breaks and returns the phone's current + pose. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ + while True: + with self._android_lock: + msg = self._latest_message or {} + + if bool(msg.get("move", False)): + ok, pos, rot, _pose = self._read_current_pose() + if ok: + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the latest 6-DoF pose received from the Android device's WebXR session. + + This method accesses the most recent pose data stored by the `_android_callback`. It uses a + thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is + then decomposed into position and rotation, and the configured camera offset is applied. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was available. + - The 3D position as a NumPy array, or None if no pose has been received yet. + - The orientation as a `Rotation` object, or None if no pose has been received. + - The raw 4x4 pose matrix as received from the teleop stream. + """ + with self._android_lock: + if self._latest_pose is None: + return False, None, None, None + p = self._latest_pose.copy() + pose = self._latest_pose + rot = Rotation.from_matrix(p[:3, :3]) + pos = p[:3, 3] - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def _android_callback(self, pose: np.ndarray, message: dict) -> None: + """ + Callback function to handle incoming data from the Android teleop stream. + + This method is executed by the `teleop` package's subscriber thread whenever a new + pose and message are received from the WebXR session on the Android phone. It updates + the internal state (`_latest_pose` and `_latest_message`) with the new data. + A thread lock is used to ensure that these shared variables are updated atomically, + preventing race conditions with the main thread that reads them. + + Args: + pose: A 4x4 NumPy array representing the phone's transformation matrix. + message: A dictionary containing additional data, such as button presses or touch events. + """ + with self._android_lock: + self._latest_pose = pose + self._latest_message = message + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + msg = self._latest_message or {} + raw_inputs["move"] = bool(msg.get("move", False)) + raw_inputs["scale"] = float(msg.get("scale", 1.0)) + raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) + raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) + + enable = bool(raw_inputs.get("move", False)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._teleop = None + if self._teleop_thread and self._teleop_thread.is_alive(): + self._teleop_thread.join(timeout=1.0) + self._teleop_thread = None + self._latest_pose = None + + +class Phone(Teleoperator): + """ + Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). + For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. + + Press and hold **B1** to enable teleoperation. While enabled, the first B1 press + captures a reference pose and rotation, when disabled and pressed again the position is reapplied. + """ + + config_class = PhoneConfig + name = "phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + + self._phone_impl: Teleoperator + + if self.config.phone_os == PhoneOS.IOS: + self._phone_impl = IOSPhone(config) + elif self.config.phone_os == PhoneOS.ANDROID: + self._phone_impl = AndroidPhone(config) + else: + raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") + + @property + def is_connected(self) -> bool: + return self._phone_impl.is_connected + + def connect(self) -> None: + return self._phone_impl.connect() + + def calibrate(self) -> None: + return self._phone_impl.calibrate() + + @property + def is_calibrated(self) -> bool: + return self._phone_impl.is_calibrated + + @property + def action_features(self) -> dict[str, type]: + return self._phone_impl.action_features + + @property + def feedback_features(self) -> dict[str, type]: + return self._phone_impl.feedback_features + + def configure(self) -> None: + return self._phone_impl.configure() + + def get_action(self) -> dict: + return self._phone_impl.get_action() + + def send_feedback(self, feedback: dict[str, float]) -> None: + return self._phone_impl.send_feedback(feedback) + + def disconnect(self) -> None: + return self._phone_impl.disconnect() diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 02e6fd22..bad7d9c3 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -12,10 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum + from .config import TeleoperatorConfig from .teleoperator import Teleoperator +class TeleopEvents(Enum): + """Shared constants for teleoperator events across teleoperators.""" + + SUCCESS = "success" + FAILURE = "failure" + RERECORD_EPISODE = "rerecord_episode" + IS_INTERVENTION = "is_intervention" + TERMINATE_EPISODE = "terminate_episode" + + def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: if config.type == "keyboard": from .keyboard import KeyboardTeleop diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 4bcc241d..47beb574 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -22,6 +22,7 @@ import traceback from contextlib import nullcontext from copy import copy from functools import cache +from typing import Any import numpy as np import torch @@ -31,10 +32,25 @@ from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.robots import Robot def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + """ + Logs performance metrics for a single step of the robot control loop. + + This function formats and prints a single line of log information, including episode/frame counters, + total loop time (dt), and detailed timings for various robot and camera operations. It can also + highlight performance drops in yellow if the actual FPS is lower than the target FPS. + + Args: + robot: The `Robot` instance, used to access its internal logs for detailed timings. + dt_s: The total duration of the control loop step in seconds. + episode_index: The index of the current episode. + frame_index: The index of the current frame within the episode. + fps: The target frames per second, used to check for performance degradation. + """ log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -80,7 +96,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f @cache def is_headless(): - """Detects if python is running without a monitor.""" + """ + Detects if the Python script is running in a headless environment (e.g., without a display). + + This function attempts to import `pynput`, a library that requires a graphical environment. + If the import fails, it assumes the environment is headless. The result is cached to avoid + re-running the check. + + Returns: + True if the environment is determined to be headless, False otherwise. + """ try: import pynput # noqa @@ -101,10 +126,35 @@ def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], use_amp: bool, task: str | None = None, robot_type: str | None = None, ): + """ + Performs a single-step inference to predict a robot action from an observation. + + This function encapsulates the full inference pipeline: + 1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension. + 2. Runs the preprocessor pipeline on the observation. + 3. Feeds the processed observation to the policy to get a raw action. + 4. Runs the postprocessor pipeline on the raw action. + 5. Formats the final action by removing the batch dimension and moving it to the CPU. + + Args: + observation: A dictionary of NumPy arrays representing the robot's current observation. + policy: The `PreTrainedPolicy` model to use for action prediction. + device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on. + preprocessor: The `PolicyProcessorPipeline` for preprocessing observations. + postprocessor: The `PolicyProcessorPipeline` for postprocessing actions. + use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference. + task: An optional string identifier for the task. + robot_type: An optional string identifier for the robot type. + + Returns: + A `torch.Tensor` containing the predicted action, ready for the robot. + """ observation = copy(observation) with ( torch.inference_mode(), @@ -122,10 +172,14 @@ def predict_action( observation["task"] = task if task else "" observation["robot_type"] = robot_type if robot_type else "" + observation = preprocessor(observation) + # Compute the next action with the policy # based on the current observation action = policy.select_action(observation) + action = postprocessor(action) + # Remove batch dimension action = action.squeeze(0) @@ -136,6 +190,18 @@ def predict_action( def init_keyboard_listener(): + """ + Initializes a non-blocking keyboard listener for real-time user interaction. + + This function sets up a listener for specific keys (right arrow, left arrow, escape) to control + the program flow during execution, such as stopping recording or exiting loops. It gracefully + handles headless environments where keyboard listening is not possible. + + Returns: + A tuple containing: + - The `pynput.keyboard.Listener` instance, or `None` if in a headless environment. + - A dictionary of event flags (e.g., `exit_early`) that are set by key presses. + """ # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. @@ -177,6 +243,19 @@ def init_keyboard_listener(): def sanity_check_dataset_name(repo_id, policy_cfg): + """ + Validates the dataset repository name against the presence of a policy configuration. + + This function enforces a naming convention: a dataset repository ID should start with "eval_" + if and only if a policy configuration is provided for evaluation purposes. + + Args: + repo_id: The Hugging Face Hub repository ID of the dataset. + policy_cfg: The configuration object for the policy, or `None`. + + Raises: + ValueError: If the naming convention is violated. + """ _, dataset_name = repo_id.split("/") # either repo_id doesnt start with "eval_" and there is no policy # or repo_id starts with "eval_" and there is a policy @@ -197,6 +276,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_robot_compatibility( dataset: LeRobotDataset, robot: Robot, fps: int, features: dict ) -> None: + """ + Checks if a dataset's metadata is compatible with the current robot and recording setup. + + This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the + dataset against the current configuration to ensure that appended data will be consistent. + + Args: + dataset: The `LeRobotDataset` instance to check. + robot: The `Robot` instance representing the current hardware setup. + fps: The current recording frequency (frames per second). + features: The dictionary of features for the current recording session. + + Raises: + ValueError: If any of the checked metadata fields do not match. + """ fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 5c29b5a8..09e64937 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _torch_available, _torch_version = is_package_available("torch", return_version=True) +_transformers_available = is_package_available("transformers") _gym_xarm_available = is_package_available("gym_xarm") _gym_aloha_available = is_package_available("gym_aloha") _gym_pusht_available = is_package_available("gym_pusht") diff --git a/src/lerobot/utils/rotation.py b/src/lerobot/utils/rotation.py new file mode 100644 index 00000000..41b65294 --- /dev/null +++ b/src/lerobot/utils/rotation.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom rotation utilities to replace scipy.spatial.transform.Rotation.""" + +import numpy as np + + +class Rotation: + """ + Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality. + + Supports conversions between rotation vectors, rotation matrices, and quaternions. + """ + + def __init__(self, quat: np.ndarray) -> None: + """Initialize rotation from quaternion [x, y, z, w].""" + self._quat = np.asarray(quat, dtype=float) + # Normalize quaternion + norm = np.linalg.norm(self._quat) + if norm > 0: + self._quat = self._quat / norm + + @classmethod + def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation": + """ + Create rotation from rotation vector using Rodrigues' formula. + + Args: + rotvec: Rotation vector [x, y, z] where magnitude is angle in radians + + Returns: + Rotation instance + """ + rotvec = np.asarray(rotvec, dtype=float) + angle = np.linalg.norm(rotvec) + + if angle < 1e-8: + # For very small angles, use identity quaternion + quat = np.array([0.0, 0.0, 0.0, 1.0]) + else: + axis = rotvec / angle + half_angle = angle / 2.0 + sin_half = np.sin(half_angle) + cos_half = np.cos(half_angle) + + # Quaternion [x, y, z, w] + quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half]) + + return cls(quat) + + @classmethod + def from_matrix(cls, matrix: np.ndarray) -> "Rotation": + """ + Create rotation from 3x3 rotation matrix. + + Args: + matrix: 3x3 rotation matrix + + Returns: + Rotation instance + """ + matrix = np.asarray(matrix, dtype=float) + + # Shepherd's method for converting rotation matrix to quaternion + trace = np.trace(matrix) + + if trace > 0: + s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw + qw = 0.25 * s + qx = (matrix[2, 1] - matrix[1, 2]) / s + qy = (matrix[0, 2] - matrix[2, 0]) / s + qz = (matrix[1, 0] - matrix[0, 1]) / s + elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]: + s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx + qw = (matrix[2, 1] - matrix[1, 2]) / s + qx = 0.25 * s + qy = (matrix[0, 1] + matrix[1, 0]) / s + qz = (matrix[0, 2] + matrix[2, 0]) / s + elif matrix[1, 1] > matrix[2, 2]: + s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy + qw = (matrix[0, 2] - matrix[2, 0]) / s + qx = (matrix[0, 1] + matrix[1, 0]) / s + qy = 0.25 * s + qz = (matrix[1, 2] + matrix[2, 1]) / s + else: + s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz + qw = (matrix[1, 0] - matrix[0, 1]) / s + qx = (matrix[0, 2] + matrix[2, 0]) / s + qy = (matrix[1, 2] + matrix[2, 1]) / s + qz = 0.25 * s + + quat = np.array([qx, qy, qz, qw]) + return cls(quat) + + @classmethod + def from_quat(cls, quat: np.ndarray) -> "Rotation": + """ + Create rotation from quaternion. + + Args: + quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring) + This implementation expects [x, y, z, w] format + + Returns: + Rotation instance + """ + return cls(quat) + + def as_matrix(self) -> np.ndarray: + """ + Convert rotation to 3x3 rotation matrix. + + Returns: + 3x3 rotation matrix + """ + qx, qy, qz, qw = self._quat + + # Compute rotation matrix from quaternion + return np.array( + [ + [1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)], + [2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)], + [2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)], + ], + dtype=float, + ) + + def as_rotvec(self) -> np.ndarray: + """ + Convert rotation to rotation vector. + + Returns: + Rotation vector [x, y, z] where magnitude is angle in radians + """ + qx, qy, qz, qw = self._quat + + # Ensure qw is positive for unique representation + if qw < 0: + qx, qy, qz, qw = -qx, -qy, -qz, -qw + + # Compute angle and axis + angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0)) + sin_half_angle = np.sqrt(1.0 - qw * qw) + + if sin_half_angle < 1e-8: + # For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz] + return 2.0 * np.array([qx, qy, qz]) + + # Extract axis and scale by angle + axis = np.array([qx, qy, qz]) / sin_half_angle + return angle * axis + + def as_quat(self) -> np.ndarray: + """ + Get quaternion representation. + + Returns: + Quaternion [x, y, z, w] + """ + return self._quat.copy() + + def apply(self, vectors: np.ndarray, inverse: bool = False) -> np.ndarray: + """ + Apply this rotation to a set of vectors. + + This is equivalent to applying the rotation matrix to the vectors: + self.as_matrix() @ vectors (or self.as_matrix().T @ vectors if inverse=True). + + Args: + vectors: Array of shape (3,) or (N, 3) representing vectors in 3D space + inverse: If True, apply the inverse of the rotation. Default is False. + + Returns: + Rotated vectors with shape: + - (3,) if input was single vector with shape (3,) + - (N, 3) in all other cases + """ + vectors = np.asarray(vectors, dtype=float) + original_shape = vectors.shape + + # Handle single vector case - ensure it's 2D for matrix multiplication + if vectors.ndim == 1: + if len(vectors) != 3: + raise ValueError("Single vector must have length 3") + vectors = vectors.reshape(1, 3) + single_vector = True + elif vectors.ndim == 2: + if vectors.shape[1] != 3: + raise ValueError("Vectors must have shape (N, 3)") + single_vector = False + else: + raise ValueError("Vectors must be 1D or 2D array") + + # Get rotation matrix + rotation_matrix = self.as_matrix() + + # Apply inverse if requested (transpose for orthogonal rotation matrices) + if inverse: + rotation_matrix = rotation_matrix.T + + # Apply rotation: (N, 3) @ (3, 3).T -> (N, 3) + rotated_vectors = vectors @ rotation_matrix.T + + # Return original shape for single vector case + if single_vector and original_shape == (3,): + return rotated_vectors.flatten() + + return rotated_vectors + + def inv(self) -> "Rotation": + """ + Invert this rotation. + + Composition of a rotation with its inverse results in an identity transformation. + + Returns: + Rotation instance containing the inverse of this rotation + """ + qx, qy, qz, qw = self._quat + + # For a unit quaternion, the inverse is the conjugate: [-x, -y, -z, w] + inverse_quat = np.array([-qx, -qy, -qz, qw]) + + return Rotation(inverse_quat) + + def __mul__(self, other: "Rotation") -> "Rotation": + """ + Compose this rotation with another rotation using the * operator. + + The composition `r2 * r1` means "apply r1 first, then r2". + This is equivalent to applying rotation matrices: r2.as_matrix() @ r1.as_matrix() + + Args: + other: Another Rotation instance to compose with + + Returns: + Rotation instance representing the composition of rotations + """ + if not isinstance(other, Rotation): + return NotImplemented + + # Get quaternions [x, y, z, w] + x1, y1, z1, w1 = other._quat # Apply first + x2, y2, z2, w2 = self._quat # Apply second + + # Quaternion multiplication: q2 * q1 (apply q1 first, then q2) + composed_quat = np.array( + [ + w2 * x1 + x2 * w1 + y2 * z1 - z2 * y1, # x component + w2 * y1 - x2 * z1 + y2 * w1 + z2 * x1, # y component + w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1, # z component + w2 * w1 - x2 * x1 - y2 * y1 - z2 * z1, # w component + ] + ) + + return Rotation(composed_quat) diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 2859fe05..be2eb814 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import PolicyProcessorPipeline from lerobot.utils.random_utils import load_rng_state, save_rng_state @@ -74,6 +75,8 @@ def save_checkpoint( policy: PreTrainedPolicy, optimizer: Optimizer, scheduler: LRScheduler | None = None, + preprocessor: PolicyProcessorPipeline | None = None, + postprocessor: PolicyProcessorPipeline | None = None, ) -> None: """This function creates the following directory structure: @@ -81,7 +84,9 @@ def save_checkpoint( ├── pretrained_model/ │ ├── config.json # policy config │ ├── model.safetensors # policy weights - │ └── train_config.json # train config + │ ├── train_config.json # train config + │ ├── processor.json # processor config (if preprocessor provided) + │ └── step_*.safetensors # processor state files (if any) └── training_state/ ├── optimizer_param_groups.json # optimizer param groups ├── optimizer_state.safetensors # optimizer state @@ -95,10 +100,15 @@ def save_checkpoint( policy (PreTrainedPolicy): The policy to save. optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. + preprocessor: The preprocessor/pipeline to save. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if preprocessor is not None: + preprocessor.save_pretrained(pretrained_dir) + if postprocessor is not None: + postprocessor.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler) diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index f0f9aebb..e6acc87d 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers import os from typing import Any @@ -28,19 +29,69 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None: rr.spawn(memory_limit=memory_limit) -def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]): - for obs, val in observation.items(): - if isinstance(val, float): - rr.log(f"observation.{obs}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - if val.ndim == 1: - for i, v in enumerate(val): - rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v))) - else: - rr.log(f"observation.{obs}", rr.Image(val), static=True) - for act, val in action.items(): - if isinstance(val, float): - rr.log(f"action.{act}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - for i, v in enumerate(val): - rr.log(f"action.{act}_{i}", rr.Scalar(float(v))) +def _is_scalar(x): + return ( + isinstance(x, float) + or isinstance(x, numbers.Real) + or isinstance(x, (np.integer, np.floating)) + or (isinstance(x, np.ndarray) and x.ndim == 0) + ) + + +def log_rerun_data( + observation: dict[str, Any] | None = None, + action: dict[str, Any] | None = None, +) -> None: + """ + Logs observation and action data to Rerun for real-time visualization. + + This function iterates through the provided observation and action dictionaries and sends their contents + to the Rerun viewer. It handles different data types appropriately: + - Scalar values (floats, ints) are logged as `rr.Scalar`. + - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed + from CHW to HWC format and logged as `rr.Image`. + - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. + - Other multi-dimensional arrays are flattened and logged as individual scalars. + + Keys are automatically namespaced with "observation." or "action." if not already present. + + Args: + observation: An optional dictionary containing observation data to log. + action: An optional dictionary containing action data to log. + """ + if observation: + for k, v in observation.items(): + if v is None: + continue + key = k if str(k).startswith("observation.") else f"observation.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + arr = v + # Convert CHW -> HWC when needed + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + if arr.ndim == 1: + for i, vi in enumerate(arr): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + rr.log(key, rr.Image(arr), static=True) + + if action: + for k, v in action.items(): + if v is None: + continue + key = k if str(k).startswith("action.") else f"action.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + if v.ndim == 1: + for i, vi in enumerate(v): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + # Fall back to flattening higher-dimensional arrays + flat = v.flatten() + for i, vi in enumerate(flat): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors index 8bd63e89..771af244 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77 +oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc size 5104 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors index 724d22b5..3e8df708 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603 -size 33400 +oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e +size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors index 6d912d81..dd7d4d0e 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b +oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors index cc6b4a24..5da67a1a 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075 -size 33400 +oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf +size 31672 diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors index 84e14b97..ef581727 100644 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c +oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8 size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors index 54229791..e00ed323 100644 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201 +oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002 size 47424 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors index e91cd08b..614cc754 100644 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22 -size 49120 +oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf +size 47424 diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 6ccb47c3..b0ffa9a3 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -23,7 +23,7 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_policy_config +from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors from lerobot.utils.random_utils import set_seed @@ -37,7 +37,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): train_cfg.validate() # Needed for auto-setting some parameters dataset = make_dataset(train_cfg) + dataset_stats = dataset.meta.stats policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) + preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats) policy.train() optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) @@ -49,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): ) batch = next(iter(dataloader)) + batch = preprocessor(batch) loss, output_dict = policy.forward(batch) + if output_dict is not None: output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} output_dict["loss"] = loss @@ -96,7 +100,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): else: actions_queue = train_cfg.policy.n_action_repeats - actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} + actions = {} + for i in range(actions_queue): + unnormalized_action = policy.select_action(obs).contiguous() + action_robot = postprocessor(unnormalized_action) + actions[str(i)] = action_robot + return output_dict, grad_stats, param_stats, actions diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors index fa9bf06a..e23eacff 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b +oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61 size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors index 8d90a671..e665f73c 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors index cde6c6dc..97d78358 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors index 692377d1..3090b705 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors index 7a0b165e..5ce44048 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb +oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors index 8d90a671..e665f73c 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors index cde6c6dc..97d78358 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors index 692377d1..3090b705 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/conftest.py b/tests/conftest.py index e273da50..245cde52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ import traceback import pytest from serial import SerialException -from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from tests.utils import DEVICE # Import fixture modules as plugins @@ -83,7 +83,9 @@ def policy_feature_factory(): return _pf -def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: +def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None: assert isinstance(features, dict) - assert all(isinstance(k, str) for k in features.keys()) - assert all(isinstance(v, PolicyFeature) for v in features.values()) + assert all(isinstance(k, PipelineFeatureType) for k in features.keys()) + assert all(isinstance(v, dict) for v in features.values()) + assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values()) + assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values()) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py new file mode 100644 index 00000000..f1ffd800 --- /dev/null +++ b/tests/datasets/test_dataset_utils.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from huggingface_hub import DatasetCard + +from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch + + +def test_default_parameters(): + card = create_lerobot_dataset_card() + assert isinstance(card, DatasetCard) + assert card.data.tags == ["LeRobot"] + assert card.data.task_categories == ["robotics"] + assert card.data.configs == [ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ] + + +def test_with_tags(): + tags = ["tag1", "tag2"] + card = create_lerobot_dataset_card(tags=tags) + assert card.data.tags == ["LeRobot", "tag1", "tag2"] + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_merge_simple_vectors(): + g1 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.x", "ee.y"], + } + } + g2 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.y", "ee.z"], + } + } + + out = combine_feature_dicts(g1, g2) + + assert "action" in out + assert out["action"]["dtype"] == "float32" + # Names merged with preserved order and de-dupuplication + assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + # Shape correctly recomputed from names length + assert out["action"]["shape"] == (3,) + + +def test_merge_multiple_groups_order_and_dedup(): + g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + + out = combine_feature_dicts(g1, g2, g3) + + assert out["action"]["names"] == ["a", "b", "c", "d"] + assert out["action"]["shape"] == (4,) + + +def test_non_vector_last_wins_for_images(): + # Non-vector (images) with same name should be overwritten by the last image specified + g1 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 480, 640), + "names": ["channels", "height", "width"], + } + } + g2 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 720, 1280), + "names": ["channels", "height", "width"], + } + } + + out = combine_feature_dicts(g1, g2) + assert out["observation.images.front"]["shape"] == (3, 720, 1280) + assert out["observation.images.front"]["dtype"] == "image" + + +def test_dtype_mismatch_raises(): + g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + + with pytest.raises(ValueError, match="dtype mismatch for 'action'"): + _ = combine_feature_dicts(g1, g2) + + +def test_non_dict_passthrough_last_wins(): + g1 = {"misc": 123} + g2 = {"misc": 456} + + out = combine_feature_dicts(g1, g2) + # For non-dict entries the last one wins + assert out["misc"] == 456 diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py deleted file mode 100644 index 91d661b3..00000000 --- a/tests/datasets/test_utils.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from copy import deepcopy - -import torch -from datasets import Dataset -from huggingface_hub import DatasetCard - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import ( - create_lerobot_dataset_card, - flatten_dict, - hf_transform_to_torch, - unflatten_dict, -) - - -def test_default_parameters(): - card = create_lerobot_dataset_card() - assert isinstance(card, DatasetCard) - assert card.data.tags == ["LeRobot"] - assert card.data.task_categories == ["robotics"] - assert card.data.configs == [ - { - "config_name": "default", - "data_files": "data/*/*.parquet", - } - ] - - -def test_with_tags(): - tags = ["tag1", "tag2"] - card = create_lerobot_dataset_card(tags=tags) - assert card.data.tags == ["LeRobot", "tag1", "tag2"] - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) - - -def test_flatten_unflatten_dict(): - d = { - "obs": { - "min": 0, - "max": 1, - "mean": 2, - "std": 3, - }, - "action": { - "min": 4, - "max": 5, - "mean": 6, - "std": 7, - }, - } - - original_d = deepcopy(d) - d = unflatten_dict(flatten_dict(d)) - - # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ef2d4ecd..ef09bcd2 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -26,7 +26,7 @@ from safetensors.torch import load_file from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features @@ -39,8 +39,8 @@ from lerobot.policies.factory import ( get_policy_class, make_policy, make_policy_config, + make_pre_post_processors, ) -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -154,6 +154,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Check that we can make the policy object. dataset = make_dataset(train_cfg) + preprocessor, _ = make_pre_post_processors(train_cfg.policy, None) policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) assert isinstance(policy, PreTrainedPolicy) @@ -227,6 +228,7 @@ def test_act_backbone_lr(): assert cfg.policy.optimizer_lr_backbone == 0.001 dataset = make_dataset(cfg) + preprocessor, _ = make_pre_post_processors(cfg.policy, None) policy = make_policy(cfg.policy, ds_meta=dataset.meta) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 @@ -266,108 +268,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) -@pytest.mark.parametrize("insert_temporal_dim", [False, True]) -def test_normalize(insert_temporal_dim): - """ - Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise - an exception when the forward pass is called without the stats having been provided. - - TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as - expected. - """ - - input_features = { - "observation.image": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 96, 96), - ), - "observation.state": PolicyFeature( - type=FeatureType.STATE, - shape=(10,), - ), - } - output_features = { - "action": PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ), - } - - norm_map = { - "VISUAL": NormalizationMode.MEAN_STD, - "STATE": NormalizationMode.MIN_MAX, - "ACTION": NormalizationMode.MIN_MAX, - } - - dataset_stats = { - "observation.image": { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - "min": torch.randn(3, 1, 1), - "max": torch.randn(3, 1, 1), - }, - "observation.state": { - "mean": torch.randn(10), - "std": torch.randn(10), - "min": torch.randn(10), - "max": torch.randn(10), - }, - "action": { - "mean": torch.randn(5), - "std": torch.randn(5), - "min": torch.randn(5), - "max": torch.randn(5), - }, - } - - bsize = 2 - input_batch = { - "observation.image": torch.randn(bsize, 3, 96, 96), - "observation.state": torch.randn(bsize, 10), - } - output_batch = { - "action": torch.randn(bsize, 5), - } - - if insert_temporal_dim: - tdim = 4 - - for key in input_batch: - # [2,3,96,96] -> [2,tdim,3,96,96] - input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) - - for key in output_batch: - output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) - - # test without stats - normalize = Normalize(input_features, norm_map, stats=None) - with pytest.raises(AssertionError): - normalize(input_batch) - - # test with stats - normalize = Normalize(input_features, norm_map, stats=dataset_stats) - normalize(input_batch) - - # test loading pretrained models - new_normalize = Normalize(input_features, norm_map, stats=None) - new_normalize.load_state_dict(normalize.state_dict()) - new_normalize(input_batch) - - # test without stats - unnormalize = Unnormalize(output_features, norm_map, stats=None) - with pytest.raises(AssertionError): - unnormalize(output_batch) - - # test with stats - unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats) - unnormalize(output_batch) - - # test loading pretrained models - new_unnormalize = Unnormalize(output_features, norm_map, stats=None) - new_unnormalize.load_state_dict(unnormalize.state_dict()) - unnormalize(output_batch) - - @pytest.mark.parametrize("multikey", [True, False]) def test_multikey_construction(multikey: bool): """ @@ -467,6 +367,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact is out of date. For example, some PyTorch versions have different randomness, see this PR: https://github.com/huggingface/lerobot/pull/1127. + NOTE: If the test don't pass and you don't change the policy, and note the dependencies version, + and you changed your processor, you might have to update the test artifact. """ diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py new file mode 100644 index 00000000..f96f871a --- /dev/null +++ b/tests/processor/test_act_processor.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ACT policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.policies.act.processor_act import make_act_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default ACT configuration for testing.""" + config = ACTConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)}, + ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)}, + } + + +def test_make_act_processor_basic(): + """Test basic creation of ACT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_act_processor_normalization(): + """Test that ACT processor correctly normalizes and unnormalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is normalized and batched + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized + assert postprocessed.shape == (1, 4) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_cuda(): + """Test ACT processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_accelerate_scenario(): + """Test ACT processor in simulated Accelerate scenario (data already on GPU).""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU + action = torch.randn(1, 4).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU (not moved unnecessarily) + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_act_processor_multi_gpu(): + """Test ACT processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU (like in multi-GPU training) + device = torch.device("cuda:1") + observation = {OBS_STATE: torch.randn(1, 7).to(device)} + action = torch.randn(1, 4).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 (not moved to cuda:0) + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_act_processor_without_stats(): + """Test ACT processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors, but normalization won't have stats + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work (but won't normalize without stats) + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_act_processor_save_and_load(): + """Test saving and loading ACT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) + + +def test_act_processor_device_placement_preservation(): + """Test that ACT processor preserves device placement correctly.""" + config = create_default_config() + stats = create_default_stats() + + # Test with CPU config + config.device = "cpu" + preprocessor, _ = make_act_pre_post_processors( + config, + stats, + ) + + # Process CPU data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_STATE].device.type == "cpu" + assert processed[TransitionKey.ACTION.value].device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_mixed_precision(): + """Test ACT processor with mixed precision (float16).""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Modify the device processor to use float16 + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} + action = torch.randn(4, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_act_processor_batch_consistency(): + """Test that ACT processor handles different batch sizes correctly.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Test single sample (unbatched) + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed["observation.state"].shape[0] == 1 # Batched + + # Test already batched data + observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 + action_batched = torch.randn(8, 4) + transition_batched = create_transition(observation_batched, action_batched) + batch_batched = transition_to_batch(transition_batched) + + processed_batched = preprocessor(batch_batched) + assert processed_batched[OBS_STATE].shape[0] == 8 + assert processed_batched[TransitionKey.ACTION.value].shape[0] == 8 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_act_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data + observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32 + action = torch.randn(4, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 63894025..631ad789 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,11 +1,7 @@ import torch -from lerobot.processor.pipeline import ( - RobotProcessor, - TransitionKey, - _default_batch_to_transition, - _default_transition_to_batch, -) +from lerobot.processor import DataProcessorPipeline, TransitionKey +from lerobot.processor.converters import batch_to_transition, transition_to_batch def _dummy_batch(): @@ -24,7 +20,7 @@ def _dummy_batch(): def test_observation_grouping_roundtrip(): """Test that observation.* keys are properly grouped and ungrouped.""" - proc = RobotProcessor([]) + proc = DataProcessorPipeline([]) batch_in = _dummy_batch() batch_out = proc(batch_in) @@ -48,19 +44,19 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): - """Test that _default_batch_to_transition correctly groups observation.* keys.""" + """Test that batch_to_transition correctly groups observation.* keys.""" batch = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), "observation.state": [1, 2, 3, 4], - "action": "action_data", + "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, "next.truncated": False, "info": {"episode": 42}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) @@ -78,7 +74,7 @@ def test_batch_to_transition_observation_grouping(): assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) assert transition[TransitionKey.REWARD] == 1.5 assert transition[TransitionKey.DONE] assert not transition[TransitionKey.TRUNCATED] @@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): - """Test that _default_transition_to_batch correctly flattens observation dict.""" + """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), @@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening(): TransitionKey.COMPLEMENTARY_DATA: {}, } - batch = _default_transition_to_batch(transition) + batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch assert "observation.image.top" in batch @@ -127,28 +123,28 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": "action_data", + "action": torch.tensor([1.0, 2.0]), "next.reward": 2.0, "next.done": False, "next.truncated": True, "info": {"test": "no_obs"}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Observation should be None when no observation.* keys assert transition[TransitionKey.OBSERVATION] is None # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0])) assert transition[TransitionKey.REWARD] == 2.0 assert not transition[TransitionKey.DONE] assert transition[TransitionKey.TRUNCATED] assert transition[TransitionKey.INFO] == {"test": "no_obs"} # Round trip should work - reconstructed_batch = _default_transition_to_batch(transition) - assert reconstructed_batch["action"] == "action_data" + reconstructed_batch = transition_to_batch(transition) + assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -157,13 +153,13 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": "minimal_action"} + batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionKey.ACTION] == "minimal_action" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults assert transition[TransitionKey.REWARD] == 0.0 @@ -173,9 +169,9 @@ def test_minimal_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["observation.state"] == "minimal_state" - assert reconstructed_batch["action"] == "minimal_action" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -186,7 +182,7 @@ def test_empty_batch(): """Test behavior with empty batch.""" batch = {} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # All fields should have defaults assert transition[TransitionKey.OBSERVATION] is None @@ -198,7 +194,7 @@ def test_empty_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["action"] is None assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -219,8 +215,8 @@ def test_complex_nested_observation(): "info": {"episode_length": 200, "success": True}, } - transition = _default_batch_to_transition(batch) - reconstructed_batch = _default_transition_to_batch(transition) + transition = batch_to_transition(batch) + reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved original_obs_keys = {k for k in batch if k.startswith("observation.")} @@ -254,7 +250,7 @@ def test_custom_converter(): def to_tr(batch): # Custom converter that modifies the reward - tr = _default_batch_to_transition(batch) + tr = batch_to_transition(batch) # Double the reward reward = tr.get(TransitionKey.REWARD, 0.0) new_tr = tr.copy() @@ -262,10 +258,10 @@ def test_custom_converter(): return new_tr def to_batch(tr): - batch = _default_transition_to_batch(tr) + batch = transition_to_batch(tr) return batch - processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) + processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) batch = { "observation.state": torch.randn(1, 4), diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py new file mode 100644 index 00000000..f7cbafd2 --- /dev/null +++ b/tests/processor/test_batch_processor.py @@ -0,0 +1,1184 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch + +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition + + +def test_state_1d_to_2d(): + """Test that 1D state tensors get unsqueezed to 2D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.state + state_1d = torch.randn(7) + observation = {OBS_STATE: state_1d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_state = result[TransitionKey.OBSERVATION][OBS_STATE] + assert processed_state.shape == (1, 7) + assert torch.allclose(processed_state.squeeze(0), state_1d) + + +def test_env_state_1d_to_2d(): + """Test that 1D environment state tensors get unsqueezed to 2D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.environment_state + env_state_1d = torch.randn(10) + observation = {OBS_ENV_STATE: env_state_1d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_env_state = result[TransitionKey.OBSERVATION][OBS_ENV_STATE] + assert processed_env_state.shape == (1, 10) + assert torch.allclose(processed_env_state.squeeze(0), env_state_1d) + + +def test_image_3d_to_4d(): + """Test that 3D image tensors get unsqueezed to 4D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.image + image_3d = torch.randn(224, 224, 3) + observation = {OBS_IMAGE: image_3d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_image = result[TransitionKey.OBSERVATION][OBS_IMAGE] + assert processed_image.shape == (1, 224, 224, 3) + assert torch.allclose(processed_image.squeeze(0), image_3d) + + +def test_multiple_images_3d_to_4d(): + """Test that 3D image tensors in observation.images.* get unsqueezed to 4D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.images.camera1 and observation.images.camera2 + image1_3d = torch.randn(64, 64, 3) + image2_3d = torch.randn(128, 128, 3) + observation = { + f"{OBS_IMAGES}.camera1": image1_3d, + f"{OBS_IMAGES}.camera2": image2_3d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + processed_image1 = processed_obs[f"{OBS_IMAGES}.camera1"] + processed_image2 = processed_obs[f"{OBS_IMAGES}.camera2"] + + assert processed_image1.shape == (1, 64, 64, 3) + assert processed_image2.shape == (1, 128, 128, 3) + assert torch.allclose(processed_image1.squeeze(0), image1_3d) + assert torch.allclose(processed_image2.squeeze(0), image2_3d) + + +def test_already_batched_tensors_unchanged(): + """Test that already batched tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + state_2d = torch.randn(1, 7) + env_state_2d = torch.randn(1, 10) + image_4d = torch.randn(1, 224, 224, 3) + + observation = { + OBS_STATE: state_2d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_4d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_2d) + assert torch.allclose(processed_obs[OBS_ENV_STATE], env_state_2d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_4d) + + +def test_higher_dimensional_tensors_unchanged(): + """Test that tensors with more dimensions than expected remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors with more dimensions + state_3d = torch.randn(2, 7, 5) # More than 1D + image_5d = torch.randn(2, 3, 224, 224, 1) # More than 3D + + observation = { + OBS_STATE: state_3d, + OBS_IMAGE: image_5d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_3d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_5d) + + +def test_non_tensor_values_unchanged(): + """Test that non-tensor values in observations remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + observation = { + OBS_STATE: [1, 2, 3], # List, not tensor + OBS_IMAGE: "not_a_tensor", # String + "custom_key": 42, # Integer + "another_key": {"nested": "dict"}, # Dict + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "not_a_tensor" + assert processed_obs["custom_key"] == 42 + assert processed_obs["another_key"] == {"nested": "dict"} + + +def test_none_observation(): + """Test processor handles None observation gracefully.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(observation={}, action=torch.empty(0)) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + + +def test_empty_observation(): + """Test processor handles empty observation dict.""" + processor = AddBatchDimensionProcessorStep() + + observation = {} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + + +def test_mixed_observation(): + """Test processor with mixed observation containing various types and dimensions.""" + processor = AddBatchDimensionProcessorStep() + + state_1d = torch.randn(5) + env_state_2d = torch.randn(1, 8) # Already batched + image_3d = torch.randn(32, 32, 3) + other_tensor = torch.randn(3, 3, 3, 3) # 4D, should be unchanged + + observation = { + OBS_STATE: state_1d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_3d, + f"{OBS_IMAGES}.front": torch.randn(64, 64, 3), # 3D, should be batched + f"{OBS_IMAGES}.back": torch.randn(1, 64, 64, 3), # 4D, should be unchanged + "other_tensor": other_tensor, + "non_tensor": "string_value", + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check transformations + assert processed_obs[OBS_STATE].shape == (1, 5) + assert processed_obs[OBS_ENV_STATE].shape == (1, 8) # Unchanged + assert processed_obs[OBS_IMAGE].shape == (1, 32, 32, 3) + assert processed_obs[f"{OBS_IMAGES}.front"].shape == (1, 64, 64, 3) + assert processed_obs[f"{OBS_IMAGES}.back"].shape == (1, 64, 64, 3) # Unchanged + assert processed_obs["other_tensor"].shape == (3, 3, 3, 3) # Unchanged + assert processed_obs["non_tensor"] == "string_value" # Unchanged + + +def test_integration_with_robot_processor(): + """Test AddBatchDimensionProcessorStep integration with RobotProcessor.""" + to_batch_processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [to_batch_processor], to_transition=identity_transition, to_output=identity_transition + ) + + # Create unbatched observation + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(224, 224, 3), + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 224, 224, 3) + + +def test_serialization_methods(): + """Test get_config, state_dict, load_state_dict, and reset methods.""" + processor = AddBatchDimensionProcessorStep() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + assert config == {} + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + assert state == {} + + # Test load_state_dict (should not raise an error) + processor.load_state_dict({}) + + # Test reset (should not raise an error) + processor.reset() + + +def test_save_and_load_pretrained(): + """Test saving and loading AddBatchDimensionProcessorStep with RobotProcessor.""" + processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [processor], name="BatchPipeline", to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check config file exists + config_path = Path(tmp_dir) / "batchpipeline.json" + assert config_path.exists() + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="batchpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert loaded_pipeline.name == "BatchPipeline" + assert len(loaded_pipeline) == 1 + assert isinstance(loaded_pipeline.steps[0], AddBatchDimensionProcessorStep) + + # Test functionality of loaded processor + observation = {OBS_STATE: torch.randn(5)} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = loaded_pipeline(transition) + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + + +def test_registry_functionality(): + """Test that AddBatchDimensionProcessorStep is properly registered.""" + # Check that the processor is registered + registered_class = ProcessorStepRegistry.get("to_batch_processor") + assert registered_class is AddBatchDimensionProcessorStep + + # Check that it's in the list of registered processors + assert "to_batch_processor" in ProcessorStepRegistry.list() + + +def test_registry_based_save_load(): + """Test saving and loading using registry name.""" + processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Verify the loaded processor works + observation = { + OBS_STATE: torch.randn(3), + OBS_IMAGE: torch.randn(100, 100, 3), + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 3) + assert processed_obs[OBS_IMAGE].shape == (1, 100, 100, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(): + """Test processor works with tensors on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors on GPU + state_1d = torch.randn(7, device="cuda") + image_3d = torch.randn(64, 64, 3, device="cuda") + + observation = { + OBS_STATE: state_1d, + OBS_IMAGE: image_3d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check shapes and that tensors stayed on GPU + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 64, 64, 3) + assert processed_obs[OBS_STATE].device.type == "cuda" + assert processed_obs[OBS_IMAGE].device.type == "cuda" + + +def test_processor_preserves_other_transition_keys(): + """Test that processor only modifies observation and preserves other transition keys.""" + processor = AddBatchDimensionProcessorStep() + + action = torch.randn(5) + reward = 1.5 + done = True + truncated = False + info = {"step": 10} + comp_data = {"extra": "data"} + + observation = {OBS_STATE: torch.randn(7)} + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + truncated=truncated, + info=info, + complementary_data=comp_data, + ) + + result = processor(transition) + + # Check that non-observation keys are preserved + assert torch.allclose(result[TransitionKey.ACTION], action) + assert result[TransitionKey.REWARD] == reward + assert result[TransitionKey.DONE] == done + assert result[TransitionKey.TRUNCATED] == truncated + assert result[TransitionKey.INFO] == info + assert result[TransitionKey.COMPLEMENTARY_DATA] == comp_data + + # Check that observation was processed + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + + +def test_edge_case_zero_dimensional_tensors(): + """Test processor handles 0D tensors (scalars) correctly.""" + processor = AddBatchDimensionProcessorStep() + + # 0D tensors should not be modified + scalar_tensor = torch.tensor(42.0) + + observation = { + OBS_STATE: scalar_tensor, + "scalar_value": scalar_tensor, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # 0D tensors should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor) + assert torch.allclose(processed_obs["scalar_value"], scalar_tensor) + + +# Action-specific tests +def test_action_1d_to_2d(): + """Test that 1D action tensors get batch dimension added.""" + processor = AddBatchDimensionProcessorStep() + + # Create 1D action tensor + action_1d = torch.randn(4) + transition = create_transition(observation={}, action=action_1d) + + result = processor(transition) + + # Should add batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action_1d) + + +def test_action_already_batched(): + """Test that already batched action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Test various batch sizes + action_batched_1 = torch.randn(1, 4) + action_batched_5 = torch.randn(5, 4) + + # Single batch + transition = create_transition(action=action_batched_1, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_1) + + # Multiple batch + transition = create_transition(action=action_batched_5, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_5) + + +def test_action_higher_dimensional(): + """Test that higher dimensional action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # 3D action tensor (e.g., sequence of actions) + action_3d = torch.randn(2, 4, 3) + transition = create_transition(action=action_3d, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_3d) + + # 4D action tensor + action_4d = torch.randn(2, 10, 4, 3) + transition = create_transition(action=action_4d, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_4d) + + +def test_action_scalar_tensor(): + """Test that scalar (0D) action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + action_scalar = torch.tensor(1.5) + transition = create_transition(action=action_scalar, observation={}) + result = processor(transition) + + # Should remain scalar + assert result[TransitionKey.ACTION].dim() == 0 + assert torch.equal(result[TransitionKey.ACTION], action_scalar) + + +def test_action_non_tensor_raises_error(): + """Test that non-tensor actions raise ValueError for PolicyAction processors.""" + processor = AddBatchDimensionProcessorStep() + + # List action should raise error + action_list = [0.1, 0.2, 0.3, 0.4] + transition = create_transition(action=action_list) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # Numpy array action should raise error + action_numpy = np.array([1, 2, 3, 4]) + transition = create_transition(action=action_numpy) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # String action should raise error + action_string = "forward" + transition = create_transition(action=action_string) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # Dict action should raise error + action_dict = {"linear": [0.5, 0.0], "angular": 0.2} + transition = create_transition(action=action_dict) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + +def test_action_none(): + """Test that empty action tensor is handled correctly.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(action=torch.empty(0), observation={}) + result = processor(transition) + # Empty 1D tensor becomes empty 2D tensor with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 0) + + +def test_action_with_observation(): + """Test action processing together with observation processing.""" + processor = AddBatchDimensionProcessorStep() + + # Both need batching + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Both should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_action_different_sizes(): + """Test action processing with various action dimensions.""" + processor = AddBatchDimensionProcessorStep() + + # Different action sizes (robot with different DOF) + action_sizes = [1, 2, 4, 7, 10, 20] + + for size in action_sizes: + action = torch.randn(size) + transition = create_transition(action=action, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, size) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_action_device_compatibility(): + """Test action processing on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # CUDA action + action_cuda = torch.randn(4, device="cuda") + transition = create_transition(action=action_cuda, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cuda" + + # CPU action + action_cpu = torch.randn(4, device="cpu") + transition = create_transition(action=action_cpu, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_action_dtype_preservation(): + """Test that action dtype is preserved during processing.""" + processor = AddBatchDimensionProcessorStep() + + # Different dtypes + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + action = torch.randn(4).to(dtype) + transition = create_transition(action=action, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].dtype == dtype + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_empty_action_tensor(): + """Test handling of empty action tensors.""" + processor = AddBatchDimensionProcessorStep() + + # Empty 1D tensor + action_empty = torch.tensor([]) + transition = create_transition(action=action_empty, observation={}) + result = processor(transition) + + # Should add batch dimension even to empty tensor + assert result[TransitionKey.ACTION].shape == (1, 0) + + # Empty 2D tensor (already batched) + action_empty_2d = torch.randn(1, 0) + transition = create_transition(action=action_empty_2d, observation={}) + result = processor(transition) + + # Should remain unchanged + assert result[TransitionKey.ACTION].shape == (1, 0) + + +# Task-specific tests +def test_task_string_to_list(): + """Test that string tasks get wrapped in lists to add batch dimension.""" + processor = AddBatchDimensionProcessorStep() + + # Create complementary data with string task + complementary_data = {"task": "pick_cube"} + transition = create_transition( + action=torch.empty(0), observation={}, complementary_data=complementary_data + ) + + result = processor(transition) + + # String task should be wrapped in list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["pick_cube"] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + +def test_task_string_validation(): + """Test that only string and list of strings are valid task values.""" + processor = AddBatchDimensionProcessorStep() + + # Valid string task - should be converted to list + complementary_data = {"task": "valid_task"} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["valid_task"] + + # Valid list of strings - should remain unchanged + complementary_data = {"task": ["task1", "task2"]} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["task1", "task2"] + + +def test_task_list_of_strings(): + """Test that lists of strings remain unchanged (already batched).""" + processor = AddBatchDimensionProcessorStep() + + # Test various list of strings + test_lists = [ + ["pick_cube"], # Single string in list + ["pick_cube", "place_cube"], # Multiple strings + ["task1", "task2", "task3"], # Three strings + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores"], # Mixed formats + ] + + for task_list in test_lists: + complementary_data = {"task": task_list} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged since it's already a list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + + +def test_complementary_data_none(): + """Test processor handles None complementary_data gracefully.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(complementary_data=None, action=torch.empty(0), observation={}) + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_empty(): + """Test processor handles empty complementary_data dict.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = {} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_no_task(): + """Test processor handles complementary_data without task field.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "episode_id": 123, + "timestamp": 1234567890.0, + "extra_info": "some data", + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data == complementary_data + + +def test_complementary_data_mixed(): + """Test processor with mixed complementary_data containing task and other fields.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "task": "stack_blocks", + "episode_id": 456, + "difficulty": "hard", + "metadata": {"scene": "kitchen"}, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be batched + assert processed_comp_data["task"] == ["stack_blocks"] + + # Other fields should remain unchanged + assert processed_comp_data["episode_id"] == 456 + assert processed_comp_data["difficulty"] == "hard" + assert processed_comp_data["metadata"] == {"scene": "kitchen"} + + +def test_task_with_observation_and_action(): + """Test task processing together with observation and action processing.""" + processor = AddBatchDimensionProcessorStep() + + # All components need batching + observation = { + OBS_STATE: torch.randn(5), + OBS_IMAGE: torch.randn(32, 32, 3), + } + action = torch.randn(4) + complementary_data = {"task": "navigate_to_goal"} + + transition = create_transition( + observation=observation, action=action, complementary_data=complementary_data + ) + + result = processor(transition) + + # All should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"] + + +def test_task_comprehensive_string_cases(): + """Test task processing with comprehensive string cases and edge cases.""" + processor = AddBatchDimensionProcessorStep() + + # Test various string formats + string_tasks = [ + "pick_and_place", + "navigate", + "open_drawer", + "", # Empty string (valid but edge case) + "task with spaces", + "task_with_underscores", + "task-with-dashes", + "UPPERCASE_TASK", + "MixedCaseTask", + "task123", + "数字任务", # Unicode task + "🤖 robot task", # Emoji in task + "task\nwith\nnewlines", # Special characters + "task\twith\ttabs", + "task with 'quotes'", + 'task with "double quotes"', + ] + + # Test that all string tasks get properly batched + for task in string_tasks: + complementary_data = {"task": task} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == [task] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + # Test various list of strings (should remain unchanged) + list_tasks = [ + ["single_task"], + ["task1", "task2"], + ["pick", "place", "navigate"], + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores", "UPPERCASE"], + ["🤖 task", "数字任务", "normal_task"], # Mixed formats + ] + + for task_list in list_tasks: + complementary_data = {"task": task_list} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + + +def test_task_preserves_other_keys(): + """Test that task processing preserves other keys in complementary_data.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "task": "clean_table", + "robot_id": "robot_123", + "motor_id": "motor_456", + "config": {"speed": "slow", "precision": "high"}, + "metrics": [1.0, 2.0, 3.0], + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be processed + assert processed_comp_data["task"] == ["clean_table"] + + # All other keys should be preserved exactly + assert processed_comp_data["robot_id"] == "robot_123" + assert processed_comp_data["motor_id"] == "motor_456" + assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"} + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + + +# Index and task_index specific tests +def test_index_scalar_to_1d(): + """Test that 0D index tensor gets unsqueezed to 1D.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D index tensor (scalar) + index_0d = torch.tensor(42, dtype=torch.int64) + complementary_data = {"index": index_0d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["index"][0] == 42 + + +def test_task_index_scalar_to_1d(): + """Test that 0D task_index tensor gets unsqueezed to 1D.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D task_index tensor (scalar) + task_index_0d = torch.tensor(7, dtype=torch.int64) + complementary_data = {"task_index": task_index_0d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"].dtype == torch.int64 + assert processed_comp_data["task_index"][0] == 7 + + +def test_index_and_task_index_together(): + """Test processing both index and task_index together.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D tensors for both + index_0d = torch.tensor(100, dtype=torch.int64) + task_index_0d = torch.tensor(3, dtype=torch.int64) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + "task": "pick_object", + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 100 + + # Check task_index + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 3 + + # Check task is also processed + assert processed_comp_data["task"] == ["pick_object"] + + +def test_index_already_batched(): + """Test that already batched index tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + index_1d = torch.tensor([42], dtype=torch.int64) + index_2d = torch.tensor([[42, 43]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"index": index_1d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d) + + # Test 2D + complementary_data = {"index": index_2d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d) + + +def test_task_index_already_batched(): + """Test that already batched task_index tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + task_index_1d = torch.tensor([7], dtype=torch.int64) + task_index_2d = torch.tensor([[7, 8]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"task_index": task_index_1d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d) + + # Test 2D + complementary_data = {"task_index": task_index_2d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d) + + +def test_index_non_tensor_unchanged(): + """Test that non-tensor index values remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "index": 42, # Plain int, not tensor + "task_index": [1, 2, 3], # List, not tensor + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"] == 42 + assert processed_comp_data["task_index"] == [1, 2, 3] + + +def test_index_dtype_preservation(): + """Test that index and task_index dtype is preserved during processing.""" + processor = AddBatchDimensionProcessorStep() + + # Test different dtypes + dtypes = [torch.int32, torch.int64, torch.long] + + for dtype in dtypes: + index_0d = torch.tensor(42, dtype=dtype) + task_index_0d = torch.tensor(7, dtype=dtype) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].dtype == dtype + assert processed_comp_data["task_index"].dtype == dtype + + +def test_index_with_full_transition(): + """Test index/task_index processing with full transition data.""" + processor = AddBatchDimensionProcessorStep() + + # Create full transition with all components + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + complementary_data = { + "task": "navigate_to_goal", + "index": torch.tensor(1000, dtype=torch.int64), + "task_index": torch.tensor(5, dtype=torch.int64), + "episode_id": 123, + } + + transition = create_transition( + observation=observation, + action=action, + reward=0.5, + done=False, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components are processed correctly + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate_to_goal"] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 1000 + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 5 + assert processed_comp_data["episode_id"] == 123 # Non-tensor field unchanged + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_index_device_compatibility(): + """Test processor works with index/task_index tensors on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors on GPU + index_0d = torch.tensor(42, dtype=torch.int64, device="cuda") + task_index_0d = torch.tensor(7, dtype=torch.int64, device="cuda") + + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check shapes and that tensors stayed on GPU + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + +def test_empty_index_tensor(): + """Test handling of empty index tensors.""" + processor = AddBatchDimensionProcessorStep() + + # Empty 0D tensor doesn't make sense, but test empty 1D + index_empty = torch.tensor([], dtype=torch.int64) + complementary_data = {"index": index_empty} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged (already 1D) + assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,) + + +def test_action_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed action.""" + processor = AddBatchDimensionProcessorStep() + + action = torch.randn(4) + transition = create_transition(action=action, observation={}) + + # Store reference to original transition + original_transition = transition + + # Process + result = processor(transition) + + # Should be a different object (functional design, not in-place mutation) + assert result is not original_transition + # Original transition should remain unchanged + assert original_transition[TransitionKey.ACTION].shape == (4,) + # Result should have correctly processed action with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +def test_task_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed task.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = {"task": "sort_objects"} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + # Store reference to original transition and complementary_data + original_transition = transition + original_comp_data = complementary_data + + # Process + result = processor(transition) + + # Should be different transition object (functional design) + assert result is not original_transition + # The task should be processed correctly (wrapped in list) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"] + # Original complementary data is also modified (current behavior) + assert original_comp_data["task"] == "sort_objects" diff --git a/tests/processor/test_classifier_processor.py b/tests/processor/test_classifier_processor.py new file mode 100644 index 00000000..139e99bd --- /dev/null +++ b/tests/processor/test_classifier_processor.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Reward Classifier processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import OBS_IMAGE, OBS_STATE +from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor +from lerobot.processor import ( + DataProcessorPipeline, + DeviceProcessorStep, + IdentityProcessorStep, + NormalizerProcessorStep, + TransitionKey, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default Reward Classifier configuration for testing.""" + config = RewardClassifierConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, # No normalization for classifier output + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + OBS_IMAGE: {}, # No normalization for images + "reward": {}, # No normalization for classifier output + } + + +def test_make_classifier_processor_basic(): + """Test basic creation of Classifier processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Check processor names + assert preprocessor.name == "classifier_preprocessor" + assert postprocessor.name == "classifier_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 3 + assert isinstance(preprocessor.steps[0], NormalizerProcessorStep) # For input features + assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) # For output features + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], DeviceProcessorStep) + assert isinstance(postprocessor.steps[1], IdentityProcessorStep) + + +def test_classifier_processor_normalization(): + """Test that Classifier processor correctly normalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create test data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) # Dummy action/reward + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is processed + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_cuda(): + """Test Classifier processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that output is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_accelerate_scenario(): + """Test Classifier processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(10).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(1).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_classifier_processor_multi_gpu(): + """Test Classifier processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(10).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(1).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_classifier_processor_without_stats(): + """Test Classifier processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_classifier_processor(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_classifier_processor_save_and_load(): + """Test saving and loading Classifier processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="classifier_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_mixed_precision(): + """Test Classifier processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(10, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(1, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_classifier_processor_batch_data(): + """Test Classifier processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Test with batched data + batch_size = 16 + observation = { + OBS_STATE: torch.randn(batch_size, 10), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 1) + + +def test_classifier_processor_postprocessor_identity(): + """Test that Classifier postprocessor uses IdentityProcessor correctly.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create test data for postprocessor + reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions + transition = create_transition(action=reward) + + _ = transition_to_batch(transition) + + # Process through postprocessor + processed = postprocessor(reward) + + # IdentityProcessor should leave values unchanged (except device) + assert torch.allclose(processed.cpu(), reward.cpu()) + assert processed.device.type == "cpu" diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py new file mode 100644 index 00000000..fc91951d --- /dev/null +++ b/tests/processor/test_converters.py @@ -0,0 +1,292 @@ +import numpy as np +import pytest +import torch + +from lerobot.processor import TransitionKey +from lerobot.processor.converters import ( + batch_to_transition, + create_transition, + to_tensor, + transition_to_batch, +) + + +# Tests for the unified to_tensor function +def test_to_tensor_numpy_arrays(): + """Test to_tensor with various numpy arrays.""" + # Regular numpy array + arr = np.array([1.0, 2.0, 3.0]) + result = to_tensor(arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Different numpy dtypes should convert to float32 by default + int_arr = np.array([1, 2, 3], dtype=np.int64) + result = to_tensor(int_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # uint8 arrays (previously "preserved") should now convert + uint8_arr = np.array([100, 150, 200], dtype=np.uint8) + result = to_tensor(uint8_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0])) + + +def test_to_tensor_numpy_scalars(): + """Test to_tensor with numpy scalars (0-dimensional arrays).""" + # numpy float32 scalar + scalar = np.float32(3.14) + result = to_tensor(scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 # Should be 0-dimensional tensor + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + # numpy int32 scalar + int_scalar = np.int32(42) + result = to_tensor(int_scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + +def test_to_tensor_python_scalars(): + """Test to_tensor with Python scalars.""" + # Python int + result = to_tensor(42) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + # Python float + result = to_tensor(3.14) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + +def test_to_tensor_sequences(): + """Test to_tensor with lists and tuples.""" + # List + result = to_tensor([1, 2, 3]) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Tuple + result = to_tensor((4.5, 5.5, 6.5)) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5])) + + +def test_to_tensor_existing_tensors(): + """Test to_tensor with existing PyTorch tensors.""" + # Tensor with same dtype should pass through with potential device change + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = to_tensor(tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, tensor) + + # Tensor with different dtype should convert + int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64) + result = to_tensor(int_tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dictionaries(): + """Test to_tensor with nested dictionaries.""" + # Simple dictionary + data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42} + result = to_tensor(data) + assert isinstance(result, dict) + assert isinstance(result["mean"], torch.Tensor) + assert isinstance(result["std"], torch.Tensor) + assert isinstance(result["count"], torch.Tensor) + assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["std"], torch.tensor([1.0, 2.0])) + assert result["count"].item() == pytest.approx(42.0) + + # Nested dictionary + nested = { + "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, + "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + } + result = to_tensor(nested) + assert isinstance(result, dict) + assert isinstance(result["action"], dict) + assert isinstance(result["observation"], dict) + assert isinstance(result["action"]["mean"], torch.Tensor) + assert isinstance(result["observation"]["mean"], torch.Tensor) + assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + + +def test_to_tensor_none_filtering(): + """Test that None values are filtered out from dictionaries.""" + data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}} + result = to_tensor(data) + assert "none_value" not in result + assert "also_none" not in result["nested"] + assert "valid" in result + assert "valid" in result["nested"] + assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dtype_parameter(): + """Test to_tensor with different dtype parameters.""" + arr = np.array([1, 2, 3]) + + # Default dtype (float32) + result = to_tensor(arr) + assert result.dtype == torch.float32 + + # Explicit float32 + result = to_tensor(arr, dtype=torch.float32) + assert result.dtype == torch.float32 + + # Float64 + result = to_tensor(arr, dtype=torch.float64) + assert result.dtype == torch.float64 + + # Preserve original dtype + float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + result = to_tensor(float64_arr, dtype=None) + assert result.dtype == torch.float64 + + +def test_to_tensor_device_parameter(): + """Test to_tensor with device parameter.""" + arr = np.array([1.0, 2.0, 3.0]) + + # CPU device (default) + result = to_tensor(arr, device="cpu") + assert result.device.type == "cpu" + + # CUDA device (if available) + if torch.cuda.is_available(): + result = to_tensor(arr, device="cuda") + assert result.device.type == "cuda" + + +def test_to_tensor_empty_dict(): + """Test to_tensor with empty dictionary.""" + result = to_tensor({}) + assert isinstance(result, dict) + assert len(result) == 0 + + +def test_to_tensor_unsupported_type(): + """Test to_tensor with unsupported types raises TypeError.""" + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor("unsupported_string") + + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor(object()) + + +def test_batch_to_transition_with_index_fields(): + """Test that batch_to_transition handles index and task_index fields correctly.""" + + # Create batch with index and task_index fields + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "next.reward": 1.5, + "next.done": False, + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + } + + transition = batch_to_transition(batch) + + # Check basic transition structure + assert TransitionKey.OBSERVATION in transition + assert TransitionKey.ACTION in transition + assert TransitionKey.COMPLEMENTARY_DATA in transition + + # Check that index and task_index are in complementary_data + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + assert "index" in comp_data + assert "task_index" in comp_data + assert "task" in comp_data + + # Verify values + assert torch.equal(comp_data["index"], batch["index"]) + assert torch.equal(comp_data["task_index"], batch["task_index"]) + assert comp_data["task"] == batch["task"] + + +def testtransition_to_batch_with_index_fields(): + """Test that transition_to_batch handles index and task_index fields correctly.""" + + # Create transition with index and task_index in complementary_data + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + reward=1.5, + done=False, + complementary_data={ + "task": ["navigate"], + "index": torch.tensor([100], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + }, + ) + + batch = transition_to_batch(transition) + + # Check that index and task_index are in the batch + assert "index" in batch + assert "task_index" in batch + assert "task" in batch + + # Verify values + assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"]) + assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"]) + assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"] + + +def test_batch_to_transition_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Batch without index/task_index + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "task": ["pick_cube"], + } + + transition = batch_to_transition(batch) + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + + # Should have task but not index/task_index + assert "task" in comp_data + assert "index" not in comp_data + assert "task_index" not in comp_data + + +def test_transition_to_batch_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Transition without index/task_index + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data={"task": ["navigate"]}, + ) + + batch = transition_to_batch(transition) + + # Should have task but not index/task_index + assert "task" in batch + assert "index" not in batch + assert "task_index" not in batch diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py new file mode 100644 index 00000000..ba00bde4 --- /dev/null +++ b/tests/processor/test_device_processor.py @@ -0,0 +1,1161 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey +from lerobot.processor.converters import create_transition, identity_transition + + +def test_basic_functionality(): + """Test basic device processor functionality on CPU.""" + processor = DeviceProcessorStep(device="cpu") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CPU + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.ACTION].device.type == "cpu" + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_functionality(): + """Test device processor functionality on CUDA.""" + processor = DeviceProcessorStep(device="cuda") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + assert result[TransitionKey.TRUNCATED].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_specific_cuda_device(): + """Test device processor with specific CUDA device.""" + processor = DeviceProcessorStep(device="cuda:0") + + observation = {"observation.state": torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].device.index == 0 + + +def test_non_tensor_values(): + """Test that non-tensor values are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.state": torch.randn(10), + "observation.metadata": {"key": "value"}, # Non-tensor data + "observation.list": [1, 2, 3], # Non-tensor data + } + action = torch.randn(5) + info = {"episode": 1, "step": 42} + + transition = create_transition(observation=observation, action=action, info=info) + + result = processor(transition) + + # Check tensors are processed + assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.ACTION], torch.Tensor) + + # Check non-tensor values are preserved + assert result[TransitionKey.OBSERVATION]["observation.metadata"] == {"key": "value"} + assert result[TransitionKey.OBSERVATION]["observation.list"] == [1, 2, 3] + assert result[TransitionKey.INFO] == {"episode": 1, "step": 42} + + +def test_none_values(): + """Test handling of None values.""" + processor = DeviceProcessorStep(device="cpu") + + # Test with None observation + transition = create_transition(observation=None, action=torch.randn(5)) + result = processor(transition) + assert result[TransitionKey.OBSERVATION] is None + assert result[TransitionKey.ACTION].device.type == "cpu" + + # Test with None action + transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + result = processor(transition) + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.ACTION] is None + + +def test_empty_observation(): + """Test handling of empty observation dictionary.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition(observation={}, action=torch.randn(5)) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_scalar_tensors(): + """Test handling of scalar tensors.""" + processor = DeviceProcessorStep(device="cpu") + + observation = {"observation.scalar": torch.tensor(1.5)} + action = torch.tensor(2.0) + reward = torch.tensor(0.5) + + transition = create_transition(observation=observation, action=action, reward=reward) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.scalar"].item() == 1.5 + assert result[TransitionKey.ACTION].item() == 2.0 + assert result[TransitionKey.REWARD].item() == 0.5 + + +def test_dtype_preservation(): + """Test that tensor dtypes are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float16) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_shape_preservation(): + """Test that tensor shapes are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.1d": torch.randn(10), + "observation.2d": torch.randn(5, 10), + "observation.3d": torch.randn(3, 224, 224), + "observation.4d": torch.randn(2, 3, 224, 224), + } + action = torch.randn(2, 5, 3) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.1d"].shape == (10,) + assert result[TransitionKey.OBSERVATION]["observation.2d"].shape == (5, 10) + assert result[TransitionKey.OBSERVATION]["observation.3d"].shape == (3, 224, 224) + assert result[TransitionKey.OBSERVATION]["observation.4d"].shape == (2, 3, 224, 224) + assert result[TransitionKey.ACTION].shape == (2, 5, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_devices(): + """Test handling of tensors already on different devices.""" + processor = DeviceProcessorStep(device="cuda") + + # Create tensors on different devices + observation = { + "observation.cpu": torch.randn(5), # CPU + "observation.cuda": torch.randn(5).cuda(), # Already on CUDA + } + action = torch.randn(3).cuda() # Already on CUDA + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # All should be on CUDA + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.cuda"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_non_blocking_flag(): + """Test that non_blocking flag is set correctly.""" + # CPU processor should have non_blocking=False + cpu_processor = DeviceProcessorStep(device="cpu") + assert cpu_processor.non_blocking is False + + if torch.cuda.is_available(): + # CUDA processor should have non_blocking=True + cuda_processor = DeviceProcessorStep(device="cuda") + assert cuda_processor.non_blocking is True + + cuda_0_processor = DeviceProcessorStep(device="cuda:0") + assert cuda_0_processor.non_blocking is True + + +def test_serialization_methods(): + """Test get_config, state_dict, and load_state_dict methods.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device) + + # Test get_config + config = processor.get_config() + assert config == {"device": device, "float_dtype": None} + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should be no-op) + processor.load_state_dict({}) + assert processor.device == device + + # Test reset (should be no-op) + processor.reset() + assert processor.device == device + + +def test_features(): + """Test that features returns features unchanged.""" + processor = DeviceProcessorStep(device="cpu") + + features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) + }, + PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + } + + result = processor.transform_features(features) + assert result == features + assert result is features # Should return the same object + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor.""" + from lerobot.constants import OBS_STATE + from lerobot.processor import AddBatchDimensionProcessorStep + + # Create a pipeline with DeviceProcessorStep + device_processor = DeviceProcessorStep(device="cpu") + batch_processor = AddBatchDimensionProcessorStep() + + processor = DataProcessorPipeline( + steps=[batch_processor, device_processor], + name="test_pipeline", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are batched and on correct device + # The result has TransitionKey.OBSERVATION as the key, with observation.state inside + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" + assert result[TransitionKey.ACTION].shape[0] == 1 # Batched + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with DeviceProcessorStep.""" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device, float_dtype="float16") + robot_processor = DataProcessorPipeline(steps=[processor], name="device_test_processor") + + with tempfile.TemporaryDirectory() as tmpdir: + # Save + robot_processor.save_pretrained(tmpdir) + + # Load + loaded_processor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="device_test_processor.json" + ) + + assert len(loaded_processor.steps) == 1 + loaded_device_processor = loaded_processor.steps[0] + assert isinstance(loaded_device_processor, DeviceProcessorStep) + # Use getattr to access attributes safely + assert ( + getattr(loaded_device_processor, "device", None) == device.split(":")[0] + ) # Device normalizes cuda:0 to cuda + assert getattr(loaded_device_processor, "float_dtype", None) == "float16" + + +def test_registry_functionality(): + """Test that DeviceProcessorStep is properly registered.""" + from lerobot.processor import ProcessorStepRegistry + + # Check that DeviceProcessorStep is registered + registered_class = ProcessorStepRegistry.get("device_processor") + assert registered_class is DeviceProcessorStep + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_performance_with_large_tensors(): + """Test performance with large tensors and non_blocking flag.""" + processor = DeviceProcessorStep(device="cuda") + + # Create large tensors + observation = { + "observation.large_image": torch.randn(10, 3, 512, 512), # Large image batch + "observation.features": torch.randn(10, 2048), # Large feature vector + } + action = torch.randn(10, 100) # Large action space + + transition = create_transition(observation=observation, action=action) + + # Process should not raise any errors + result = processor(transition) + + # Verify all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.large_image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.features"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_reward_done_truncated_types(): + """Test handling of different types for reward, done, and truncated.""" + processor = DeviceProcessorStep(device="cpu") + + # Test with scalar values (not tensors) + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=1.0, # float + done=False, # bool + truncated=True, # bool + ) + + result = processor(transition) + + # Non-tensor values should be preserved as-is + assert result[TransitionKey.REWARD] == 1.0 + assert result[TransitionKey.DONE] is False + assert result[TransitionKey.TRUNCATED] is True + + # Test with tensor values + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=torch.tensor(1.0), + done=torch.tensor(False), + truncated=torch.tensor(True), + ) + + result = processor(transition) + + # Tensor values should be moved to device + assert isinstance(result[TransitionKey.REWARD], torch.Tensor) + assert isinstance(result[TransitionKey.DONE], torch.Tensor) + assert isinstance(result[TransitionKey.TRUNCATED], torch.Tensor) + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +def test_complementary_data_preserved(): + """Test that complementary_data is preserved unchanged.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": "pick_object", + "episode_id": 42, + "metadata": {"sensor": "camera_1"}, + "observation_is_pad": torch.tensor([False, False, True]), # This should be moved to device + } + + transition = create_transition( + observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + ) + + result = processor(transition) + + # Check that complementary_data is preserved + assert TransitionKey.COMPLEMENTARY_DATA in result + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object" + assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42 + assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"} + # Note: Currently DeviceProcessorStep doesn't process tensors in complementary_data + # This is intentional as complementary_data is typically metadata + + +def test_float_dtype_conversion(): + """Test float dtype conversion functionality.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float16") + + # Create tensors of different types + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float32) + reward = torch.tensor(1.0, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action, reward=reward) + result = processor(transition) + + # Check that float tensors are converted to float16 + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check that non-float tensors are preserved + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + + +def test_float_dtype_none(): + """Test that when float_dtype is None, no dtype conversion occurs.""" + processor = DeviceProcessorStep(device="cpu", float_dtype=None) + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that dtypes are preserved when float_dtype is None + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_bfloat16(): + """Test conversion to bfloat16.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") + + observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.ACTION].dtype == torch.bfloat16 + + +def test_float_dtype_float64(): + """Test conversion to float64.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float64") + + observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + action = torch.randn(3, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_invalid(): + """Test that invalid float_dtype raises ValueError.""" + with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"): + DeviceProcessorStep(device="cpu", float_dtype="invalid_dtype") + + +def test_float_dtype_aliases(): + """Test that dtype aliases work correctly.""" + # Test 'half' alias for float16 + processor_half = DeviceProcessorStep(device="cpu", float_dtype="half") + assert processor_half._target_float_dtype == torch.float16 + + # Test 'float' alias for float32 + processor_float = DeviceProcessorStep(device="cpu", float_dtype="float") + assert processor_float._target_float_dtype == torch.float32 + + # Test 'double' alias for float64 + processor_double = DeviceProcessorStep(device="cpu", float_dtype="double") + assert processor_double._target_float_dtype == torch.float64 + + +def test_float_dtype_with_mixed_tensors(): + """Test float dtype conversion with mixed tensor types.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float32") + + observation = { + "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert + "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert + } + action = torch.randn(5, dtype=torch.float16) # Should convert + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check conversions + assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged + assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted + + +def test_float_dtype_serialization(): + """Test that float_dtype is properly serialized in get_config.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device, float_dtype="float16") + config = processor.get_config() + + assert config == {"device": device, "float_dtype": "float16"} + + # Test with None float_dtype + processor_none = DeviceProcessorStep(device="cpu", float_dtype=None) + config_none = processor_none.get_config() + + assert config_none == {"device": "cpu", "float_dtype": None} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_float_dtype_with_cuda(): + """Test float dtype conversion combined with CUDA device.""" + processor = DeviceProcessorStep(device="cuda", float_dtype="float16") + + # Create tensors on CPU with different dtypes + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.int64": torch.tensor([1, 2, 3], dtype=torch.int64), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are on CUDA and float types are converted + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + + assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 # Unchanged + + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_complementary_data_index_fields(): + """Test processing of index and task_index fields in complementary_data.""" + processor = DeviceProcessorStep(device="cpu") + + # Create transition with index and task_index in complementary_data + complementary_data = { + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "episode_id": 123, # Non-tensor field + } + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check that tensors in complementary_data are processed + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index tensor + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert processed_comp_data["index"].device.type == "cpu" + assert torch.equal(processed_comp_data["index"], complementary_data["index"]) + + # Check task_index tensor + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + assert processed_comp_data["task_index"].device.type == "cpu" + assert torch.equal(processed_comp_data["task_index"], complementary_data["task_index"]) + + # Check non-tensor fields remain unchanged + assert processed_comp_data["task"] == ["pick_cube"] + assert processed_comp_data["episode_id"] == 123 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_index_fields_cuda(): + """Test moving index and task_index fields to CUDA.""" + processor = DeviceProcessorStep(device="cuda:0") + + # Create CPU tensors + complementary_data = { + "index": torch.tensor([100, 101], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors moved to CUDA + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["index"].device.index == 0 + assert processed_comp_data["task_index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.index == 0 + + +def test_complementary_data_without_index_fields(): + """Test that complementary_data without index/task_index fields works correctly.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": ["navigate"], + "episode_id": 456, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should process without errors and preserve non-tensor fields + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate"] + assert processed_comp_data["episode_id"] == 456 + + +def test_complementary_data_mixed_tensors(): + """Test complementary_data with mix of tensors and non-tensors.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": ["pick_and_place"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "metrics": [1.0, 2.0, 3.0], # List, not tensor + "config": {"speed": "fast"}, # Dict + "episode_id": 789, # Int + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors are processed + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + + # Check non-tensors remain unchanged + assert processed_comp_data["task"] == ["pick_and_place"] + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + assert processed_comp_data["config"] == {"speed": "fast"} + assert processed_comp_data["episode_id"] == 789 + + +def test_complementary_data_float_dtype_conversion(): + """Test that float dtype conversion doesn't affect int tensors in complementary_data.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float16") + + complementary_data = { + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "float_tensor": torch.tensor([1.5, 2.5], dtype=torch.float32), # Should be converted + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Int tensors should keep their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + # Float tensor should be converted + assert processed_comp_data["float_tensor"].dtype == torch.float16 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_full_pipeline_cuda(): + """Test full transition with complementary_data on CUDA.""" + processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") + + # Create full transition with mixed CPU tensors + observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + action = torch.randn(1, 4, dtype=torch.float32) + reward = torch.tensor(1.5, dtype=torch.float32) + done = torch.tensor(False) + complementary_data = { + "task": ["reach_target"], + "index": torch.tensor([1000], dtype=torch.int64), + "task_index": torch.tensor([10], dtype=torch.int64), + } + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components moved to CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + + # Check complementary_data tensors + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + # Check float conversion happened for float tensors + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check int tensors kept their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + +def test_complementary_data_empty(): + """Test empty complementary_data handling.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data={}, + ) + + result = processor(transition) + + # Should have empty dict + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_none(): + """Test None complementary_data handling.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data=None, + ) + + result = processor(transition) + + # Complementary data should not be in the result (same as input) + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_preserves_gpu_placement(): + """Test that DeviceProcessorStep preserves GPU placement when tensor is already on GPU.""" + processor = DeviceProcessorStep(device="cuda:0") + + # Create tensors already on GPU + observation = { + "observation.state": torch.randn(10).cuda(), # Already on GPU + "observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU + } + action = torch.randn(5).cuda() # Already on GPU + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors remain on their original GPU + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + # Verify no unnecessary copies were made (same data pointer) + assert torch.equal( + result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"] + ) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_preservation(): + """Test that DeviceProcessorStep preserves placement on different GPUs in multi-GPU setup.""" + # Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input) + processor_gpu = DeviceProcessorStep(device="cuda:0") + + # Create tensors on cuda:1 (simulating Accelerate placement) + cuda1_device = torch.device("cuda:1") + observation = { + "observation.state": torch.randn(10).to(cuda1_device), + "observation.image": torch.randn(3, 224, 224).to(cuda1_device), + } + action = torch.randn(5).to(cuda1_device) + + transition = create_transition(observation=observation, action=action) + result = processor_gpu(transition) + + # Check that tensors remain on cuda:1 (not moved to cuda:0) + assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device + assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device + assert result[TransitionKey.ACTION].device == cuda1_device + + # Test 2: GPU-to-CPU should move to CPU (not preserve GPU) + processor_cpu = DeviceProcessorStep(device="cpu") + + transition_gpu = create_transition( + observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() + ) + result_cpu = processor_cpu(transition_gpu) + + # Check that tensors are moved to CPU + assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result_cpu[TransitionKey.ACTION].device.type == "cpu" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_with_cpu_tensors(): + """Test that CPU tensors are moved to configured device even in multi-GPU context.""" + # Processor configured for cuda:1 + processor = DeviceProcessorStep(device="cuda:1") + + # Mix of CPU and GPU tensors + observation = { + "observation.cpu": torch.randn(10), # CPU tensor + "observation.gpu0": torch.randn(10).cuda(0), # Already on cuda:0 + "observation.gpu1": torch.randn(10).cuda(1), # Already on cuda:1 + } + action = torch.randn(5) # CPU tensor + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # CPU tensor should move to configured device (cuda:1) + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 1 + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].device.index == 1 + + # GPU tensors should stay on their original devices + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_with_float_dtype(): + """Test float dtype conversion works correctly with multi-GPU preservation.""" + processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") + + # Create float tensors on different GPUs + observation = { + "observation.gpu0": torch.randn(5, dtype=torch.float32).cuda(0), + "observation.gpu1": torch.randn(5, dtype=torch.float32).cuda(1), + "observation.cpu": torch.randn(5, dtype=torch.float32), # CPU + } + + transition = create_transition(observation=observation) + result = processor(transition) + + # Check device placement + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1 + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 0 # Moved to cuda:0 + + # Check dtype conversion happened for all + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.cpu"].dtype == torch.float16 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_simulated_accelerate_scenario(): + """Test a scenario simulating how Accelerate would use the processor.""" + # Simulate different processes getting different GPU assignments + for gpu_id in range(min(torch.cuda.device_count(), 2)): + # Each "process" has a processor configured for cuda:0 + # but data comes in already placed on the process's GPU + processor = DeviceProcessorStep(device="cuda:0") + + # Simulate data already placed by Accelerate + device = torch.device(f"cuda:{gpu_id}") + observation = {"observation.state": torch.randn(1, 10).to(device)} + action = torch.randn(1, 5).to(device) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Verify data stays on the GPU where Accelerate placed it + assert result[TransitionKey.OBSERVATION]["observation.state"].device == device + assert result[TransitionKey.ACTION].device == device + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_policy_processor_integration(): + """Test integration with policy processors - input on GPU, output on CPU.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.constants import ACTION, OBS_STATE + from lerobot.processor import ( + AddBatchDimensionProcessorStep, + NormalizerProcessorStep, + UnnormalizerProcessorStep, + ) + + # Create features and stats + features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + + stats = { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + ACTION: {"mean": torch.zeros(5), "std": torch.ones(5)}, + } + + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD} + + # Create input processor (preprocessor) that moves to GPU + input_processor = DataProcessorPipeline( + steps=[ + NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device="cuda"), + ], + name="test_preprocessor", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Create output processor (postprocessor) that moves to CPU + output_processor = DataProcessorPipeline( + steps=[ + DeviceProcessorStep(device="cpu"), + UnnormalizerProcessorStep(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats), + ], + name="test_postprocessor", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test data on CPU + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation=observation, action=action) + + # Process through input processor + input_result = input_processor(transition) + + # Verify tensors are on GPU and batched + # The result has TransitionKey.OBSERVATION as the key, with observation.state inside + assert input_result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert input_result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 + assert input_result[TransitionKey.ACTION].device.type == "cuda" + assert input_result[TransitionKey.ACTION].shape[0] == 1 + + # Simulate model output on GPU + model_output = create_transition(action=torch.randn(1, 5).cuda()) + + # Process through output processor + output_result = output_processor(model_output) + + # Verify action is back on CPU and unnormalized + assert output_result[TransitionKey.ACTION].device.type == "cpu" + assert output_result[TransitionKey.ACTION].shape == (1, 5) + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_float64_compatibility(): + """Test MPS device compatibility with float64 tensors (automatic conversion to float32).""" + processor = DeviceProcessorStep(device="mps") + + # Create tensors with different dtypes, including float64 which MPS doesn't support + observation = { + "observation.float64": torch.randn(5, dtype=torch.float64), # Should be converted to float32 + "observation.float32": torch.randn(5, dtype=torch.float32), # Should remain float32 + "observation.float16": torch.randn(5, dtype=torch.float16), # Should remain float16 + "observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64), # Should remain int64 + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), # Should remain bool + } + action = torch.randn(3, dtype=torch.float64) # Should be converted to float32 + reward = torch.tensor(1.0, dtype=torch.float64) # Should be converted to float32 + done = torch.tensor(False, dtype=torch.bool) # Should remain bool + truncated = torch.tensor(True, dtype=torch.bool) # Should remain bool + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on MPS device + assert result[TransitionKey.OBSERVATION]["observation.float64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float16"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.bool"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + assert result[TransitionKey.REWARD].device.type == "mps" + assert result[TransitionKey.DONE].device.type == "mps" + assert result[TransitionKey.TRUNCATED].device.type == "mps" + + # Check that float64 tensors were automatically converted to float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float32 + assert result[TransitionKey.ACTION].dtype == torch.float32 + assert result[TransitionKey.REWARD].dtype == torch.float32 + + # Check that other dtypes were preserved + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float16"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + assert result[TransitionKey.DONE].dtype == torch.bool + assert result[TransitionKey.TRUNCATED].dtype == torch.bool + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_float64_with_complementary_data(): + """Test MPS float64 conversion with complementary_data tensors.""" + processor = DeviceProcessorStep(device="mps") + + # Create complementary_data with float64 tensors + complementary_data = { + "task": ["pick_object"], + "index": torch.tensor([42], dtype=torch.int64), # Should remain int64 + "task_index": torch.tensor([3], dtype=torch.int64), # Should remain int64 + "float64_tensor": torch.tensor([1.5, 2.5], dtype=torch.float64), # Should convert to float32 + "float32_tensor": torch.tensor([3.5], dtype=torch.float32), # Should remain float32 + } + + transition = create_transition( + observation={"observation.state": torch.randn(5, dtype=torch.float64)}, + action=torch.randn(3, dtype=torch.float64), + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check that all tensors are on MPS device + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].device.type == "mps" + assert processed_comp_data["task_index"].device.type == "mps" + assert processed_comp_data["float64_tensor"].device.type == "mps" + assert processed_comp_data["float32_tensor"].device.type == "mps" + + # Check dtype conversions + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted + assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted + assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged + assert processed_comp_data["index"].dtype == torch.int64 # Unchanged + assert processed_comp_data["task_index"].dtype == torch.int64 # Unchanged + + # Check non-tensor data preserved + assert processed_comp_data["task"] == ["pick_object"] + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_with_explicit_float_dtype(): + """Test MPS device with explicit float_dtype setting.""" + # Test that explicit float_dtype still works on MPS + processor = DeviceProcessorStep(device="mps", float_dtype="float16") + + observation = { + "observation.float64": torch.randn( + 5, dtype=torch.float64 + ), # First converted to float32, then to float16 + "observation.float32": torch.randn(5, dtype=torch.float32), # Converted to float16 + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), # Should remain int32 + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check device placement + assert result[TransitionKey.OBSERVATION]["observation.float64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.int32"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + + # Check that all float tensors end up as float16 (the target dtype) + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + + # Check that non-float tensors are preserved + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_serialization(): + """Test that MPS device processor can be serialized and loaded correctly.""" + processor = DeviceProcessorStep(device="mps", float_dtype="float32") + + # Test get_config + config = processor.get_config() + assert config == {"device": "mps", "float_dtype": "float32"} + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should be no-op) + processor.load_state_dict({}) + assert processor.device == "mps" diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py new file mode 100644 index 00000000..5d280f9c --- /dev/null +++ b/tests/processor/test_diffusion_processor.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Diffusion policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default Diffusion configuration for testing.""" + config = DiffusionConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_diffusion_processor_basic(): + """Test basic creation of Diffusion processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_diffusion_processor_with_images(): + """Test Diffusion processor with image observations.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create test data with images + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is batched + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_cuda(): + """Test Diffusion processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_accelerate_scenario(): + """Test Diffusion processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 7).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_diffusion_processor_multi_gpu(): + """Test Diffusion processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 7).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_diffusion_processor_without_stats(): + """Test Diffusion processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_diffusion_processor_save_and_load(): + """Test saving and loading Diffusion processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +def test_diffusion_processor_identity_normalization(): + """Test that images with IDENTITY normalization are not normalized.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create test data + image_value = torch.rand(3, 224, 224) * 255 # Large values + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: image_value.clone(), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Image should not be normalized (IDENTITY mode) + # Just batched + assert torch.allclose(processed[OBS_IMAGE][0], image_value, rtol=1e-5) + + +def test_diffusion_processor_batch_consistency(): + """Test Diffusion processor with different batch sizes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Test with different batch sizes + for batch_size in [1, 8, 32]: + observation = { + OBS_STATE: torch.randn(batch_size, 7) if batch_size > 1 else torch.randn(7), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224), + } + action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + + # Check correct batch size + expected_batch = batch_size if batch_size > 1 else 1 + assert processed[OBS_STATE].shape[0] == expected_batch + assert processed[OBS_IMAGE].shape[0] == expected_batch + assert processed[TransitionKey.ACTION.value].shape[0] == expected_batch + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_diffusion_pre_post_processors(config, stats) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(7, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py new file mode 100644 index 00000000..6bed8289 --- /dev/null +++ b/tests/processor/test_migration_detection.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for processor migration detection functionality. +""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError + + +def test_is_processor_config_valid_configs(): + """Test processor config detection with valid configurations.""" + valid_configs = [ + {"steps": []}, # Empty steps + {"steps": [{"class": "MyClass"}]}, # Class-based step + {"steps": [{"registry_name": "my_step"}]}, # Registry-based step + {"steps": [{"class": "A"}, {"registry_name": "B"}]}, # Mixed + {"name": "Test", "steps": [{"class": "MyClass"}]}, # With name + ] + + for i, config in enumerate(valid_configs): + assert DataProcessorPipeline._is_processor_config(config), ( + f"Valid config {i} should be detected as processor config: {config}" + ) + + +def test_is_processor_config_invalid_configs(): + """Test processor config detection with invalid configurations.""" + invalid_configs = [ + {}, # No steps field + {"steps": "not a list"}, # Steps is not a list + {"steps": [{}]}, # Step without class or registry_name + {"steps": ["not a dict"]}, # Step is not a dict + {"steps": [{"other_field": "value"}]}, # Step with wrong fields + {"other_field": "value"}, # Completely different structure + ] + + for i, config in enumerate(invalid_configs): + assert not DataProcessorPipeline._is_processor_config(config), ( + f"Invalid config {i} should not be detected as processor config: {config}" + ) + + +def test_should_suggest_migration_with_processor_config(): + """Test that migration is NOT suggested when processor config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a valid processor config + processor_config = { + "name": "TestProcessor", + "steps": [ + { + "class": "lerobot.processor.normalize.NormalizeStep", + "config": {"mean": 0.0, "std": 1.0}, + } + ], + } + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + # Should NOT suggest migration (processor config exists) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_with_empty_processor_config(): + """Test that migration is NOT suggested when empty processor config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create an empty processor config + empty_processor_config = { + "name": "EmptyProcessor", + "steps": [], # Empty steps is valid + } + + with open(tmp_path / "empty_processor.json", "w") as f: + json.dump(empty_processor_config, f) + + # Should NOT suggest migration (processor config exists, even if empty) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_with_model_config_only(): + """Test that migration IS suggested when only model config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config (like old LeRobot format) + model_config = { + "type": "act", + "input_features": {"observation.state": {"shape": [7]}}, + "output_features": {"action": {"shape": [7]}}, + "hidden_dim": 256, + "n_obs_steps": 1, + "n_action_steps": 1, + } + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # SHOULD suggest migration (model config exists but no processor) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_should_suggest_migration_no_json_files(): + """Test that migration is NOT suggested when no JSON files exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some non-JSON files + with open(tmp_path / "model.safetensors", "w") as f: + f.write("fake model data") + + with open(tmp_path / "README.md", "w") as f: + f.write("# Model README") + + # Should NOT suggest migration (no JSON files) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_random_json_files(): + """Test that migration IS suggested when JSON files exist but none are processor configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some random JSON file (not a processor config) + random_config = {"some_field": "some_value", "another_field": 123} + + with open(tmp_path / "random.json", "w") as f: + json.dump(random_config, f) + + # SHOULD suggest migration (JSON files exist but none are processor configs) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_should_suggest_migration_mixed_configs(): + """Test that migration is NOT suggested when processor config exists alongside other configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create both a processor config and a model config + processor_config = {"name": "TestProcessor", "steps": [{"registry_name": "normalize_step"}]} + + model_config = {"type": "diffusion", "hidden_dim": 512} + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should NOT suggest migration (processor config exists) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_invalid_json(): + """Test that invalid JSON is handled gracefully.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create an invalid JSON file + with open(tmp_path / "invalid.json", "w") as f: + f.write("{ invalid json") + + # Create a valid non-processor config + model_config = {"type": "act"} + with open(tmp_path / "model.json", "w") as f: + json.dump(model_config, f) + + # SHOULD suggest migration (invalid JSON is ignored, but we have non-processor JSON) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_from_pretrained_multiple_json_files_migration_error(): + """Test that multiple JSON files trigger ProcessorMigrationError.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create multiple non-processor configs + model_config = {"type": "act", "hidden_dim": 128} + train_config = {"batch_size": 32, "lr": 0.001} + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + with open(tmp_path / "train_config.json", "w") as f: + json.dump(train_config, f) + + # Should raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") + + # Check the error details + error = exc_info.value + assert str(tmp_path) in str(error.model_path) + assert "migrate_policy_normalization.py" in error.migration_command + assert "not a valid processor configuration" in error.original_error + + +def test_from_pretrained_no_processor_config_migration_error(): + """Test that missing processor config triggers ProcessorMigrationError.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config but no processor + model_config = {"type": "diffusion", "hidden_dim": 256} + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") + + # Check the error details + error = exc_info.value + assert str(tmp_path) in str(error.model_path) + assert "migrate_policy_normalization.py" in error.migration_command + assert "not a valid processor configuration" in error.original_error + + +def test_from_pretrained_valid_processor_no_migration_error(): + """Test that valid processor config does NOT trigger migration error.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a valid processor config + processor_config = { + "name": "TestProcessor", + "steps": [], # Empty is valid + } + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + # Should succeed and create pipeline + pipeline = DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json") + assert pipeline is not None + assert pipeline.name == "TestProcessor" + assert len(pipeline) == 0 + + +def test_from_pretrained_no_json_files_no_migration_error(): + """Test that directories with no JSON files don't trigger migration errors.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some non-JSON files + with open(tmp_path / "model.safetensors", "w") as f: + f.write("fake model data") + + # Should raise FileNotFoundError (config file not found) + with pytest.raises(FileNotFoundError, match="not found in directory"): + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json") + + +def test_processor_migration_error_creation(): + """Test that ProcessorMigrationError is created correctly.""" + model_path = "/path/to/model" + migration_command = "python migrate.py --path /path/to/model" + original_error = "Config not found" + + error = ProcessorMigrationError(model_path, migration_command, original_error) + + assert error.model_path == model_path + assert error.migration_command == migration_command + assert error.original_error == original_error + assert model_path in str(error) + assert migration_command in str(error) + assert original_error in str(error) + + +def test_processor_migration_error_attributes(): + """Test that ProcessorMigrationError has correct attributes.""" + model_path = Path("/test/path") + migration_command = "python test.py" + original_error = "Test error" + + error = ProcessorMigrationError(model_path, migration_command, original_error) + + # Test that attributes are accessible + assert hasattr(error, "model_path") + assert hasattr(error, "migration_command") + assert hasattr(error, "original_error") + + # Test that it's still an Exception + assert isinstance(error, Exception) + + +def test_migration_suggestion_raises_error(): + """Test that migration suggestion always raises ProcessorMigrationError.""" + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline._suggest_processor_migration("/test/path", "Test error") + + error = exc_info.value + assert "/test/path" in str(error.model_path) + assert "Test error" in error.original_error + assert "migrate_policy_normalization.py" in error.migration_command + + +def test_migration_error_always_raised_for_invalid_configs(): + """Test that ProcessorMigrationError is always raised for invalid configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config + model_config = {"type": "test", "param": "value"} + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should always raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 26aea56c..5d779191 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -20,27 +20,16 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.processor.normalize_processor import ( - NormalizerProcessor, - UnnormalizerProcessor, - _convert_stats_to_tensors, +from lerobot.processor import ( + DataProcessorPipeline, + IdentityProcessorStep, + NormalizerProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, + hotswap_stats, ) -from lerobot.processor.pipeline import RobotProcessor, TransitionKey - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } +from lerobot.processor.converters import create_transition, identity_transition, to_tensor +from lerobot.utils.utils import auto_select_torch_device def test_numpy_conversion(): @@ -50,7 +39,7 @@ def test_numpy_conversion(): "std": np.array([0.2, 0.2, 0.2]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) @@ -65,7 +54,7 @@ def test_tensor_conversion(): "std": torch.tensor([1.0, 1.0]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert tensor_stats["action"]["mean"].dtype == torch.float32 assert tensor_stats["action"]["std"].dtype == torch.float32 @@ -78,7 +67,7 @@ def test_scalar_conversion(): "std": 0.1, } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) @@ -91,7 +80,7 @@ def test_list_conversion(): "max": [1.0, 1.0, 2.0], } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) @@ -104,7 +93,7 @@ def test_unsupported_type(): } } with pytest.raises(TypeError, match="Unsupported type"): - _convert_stats_to_tensors(stats) + to_tensor(stats) # Helper functions to create feature maps and norm maps @@ -122,7 +111,7 @@ def _create_observation_norm_map(): } -# Fixtures for observation normalisation tests using NormalizerProcessor +# Fixtures for observation normalisation tests using NormalizerProcessorStep @pytest.fixture def observation_stats(): return { @@ -139,10 +128,10 @@ def observation_stats(): @pytest.fixture def observation_normalizer(observation_stats): - """Return a NormalizerProcessor that only has observation stats (no action).""" + """Return a NormalizerProcessorStep that only has observation stats (no action).""" features = _create_observation_features() norm_map = _create_observation_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) def test_mean_std_normalization(observation_normalizer): @@ -179,8 +168,11 @@ def test_min_max_normalization(observation_normalizer): def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + normalizer = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=observation_stats, + normalize_observation_keys={"observation.image"}, ) observation = { @@ -202,7 +194,7 @@ def test_selective_normalization(observation_stats): def test_device_compatibility(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), } @@ -231,7 +223,7 @@ def test_from_lerobot_dataset(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } - normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats assert "observation.image" in normalizer._tensor_stats @@ -241,11 +233,12 @@ def test_from_lerobot_dataset(): def test_state_dict_save_load(observation_normalizer): # Save state state_dict = observation_normalizer.state_dict() + print("State dict:", state_dict) # Create new normalizer and load state features = _create_observation_features() norm_map = _create_observation_norm_map() - new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) new_normalizer.load_state_dict(state_dict) # Test that it works the same @@ -296,7 +289,7 @@ def _create_action_norm_map_min_max(): def test_mean_std_unnormalization(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) @@ -314,7 +307,7 @@ def test_mean_std_unnormalization(action_stats_mean_std): def test_min_max_unnormalization(action_stats_min_max): features = _create_action_features() norm_map = _create_action_norm_map_min_max() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_min_max} ) @@ -337,14 +330,14 @@ def test_min_max_unnormalization(action_stats_min_max): assert torch.allclose(unnormalized_action, expected) -def test_numpy_action_input(action_stats_mean_std): +def test_tensor_action_input(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) - normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32) transition = create_transition(action=normalized_action) unnormalized_transition = unnormalizer(transition) @@ -358,7 +351,7 @@ def test_numpy_action_input(action_stats_mean_std): def test_none_action(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) @@ -374,11 +367,11 @@ def test_action_from_lerobot_dataset(): mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} - unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) assert "mean" in unnormalizer._tensor_stats["action"] -# Fixtures for NormalizerProcessor tests +# Fixtures for NormalizerProcessorStep tests @pytest.fixture def full_stats(): return { @@ -417,7 +410,7 @@ def _create_full_norm_map(): def normalizer_processor(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) + return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=full_stats) def test_combined_normalization(normalizer_processor): @@ -461,11 +454,11 @@ def test_processor_from_lerobot_dataset(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - processor = NormalizerProcessor.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_keys={"observation.image"} + processor = NormalizerProcessorStep.from_lerobot_dataset( + mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} ) - assert processor.normalize_keys == {"observation.image"} + assert processor.normalize_observation_keys == {"observation.image"} assert "observation.image" in processor._tensor_stats assert "action" in processor._tensor_stats @@ -473,13 +466,17 @@ def test_processor_from_lerobot_dataset(full_stats): def test_get_config(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + processor = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_keys": ["observation.image"], + "normalize_observation_keys": ["observation.image"], "eps": 1e-6, "features": { "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, @@ -497,7 +494,9 @@ def test_get_config(full_stats): def test_integration_with_robot_processor(normalizer_processor): """Test integration with RobotProcessor pipeline""" - robot_processor = RobotProcessor([normalizer_processor]) + robot_processor = DataProcessorPipeline( + [normalizer_processor], to_transition=identity_transition, to_output=identity_transition + ) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), @@ -526,7 +525,7 @@ def test_empty_observation(): stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) transition = create_transition() result = normalizer(transition) @@ -537,7 +536,7 @@ def test_empty_observation(): def test_empty_stats(): features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) observation = {"observation.image": torch.tensor([0.5])} transition = create_transition(observation=observation) @@ -553,7 +552,7 @@ def test_partial_stats(): stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = {"observation.image": torch.tensor([0.7])} transition = create_transition(observation=observation) @@ -568,7 +567,7 @@ def test_missing_action_stats_no_error(): features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key assert "action" not in processor._tensor_stats @@ -577,19 +576,23 @@ def test_serialization_roundtrip(full_stats): """Test that features and norm_map can be serialized and deserialized correctly.""" features = _create_full_features() norm_map = _create_full_norm_map() - original_processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + original_processor = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) # Get config (serialization) config = original_processor.get_config() # Create a new processor from the config (deserialization) - new_processor = NormalizerProcessor( + new_processor = NormalizerProcessorStep( features=config["features"], norm_map=config["norm_map"], stats=full_stats, - normalize_keys=set(config["normalize_keys"]), + normalize_observation_keys=set(config["normalize_observation_keys"]), eps=config["eps"], ) @@ -620,9 +623,1299 @@ def test_serialization_roundtrip(full_stats): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # Verify features and norm_map are correctly reconstructed - assert new_processor.features.keys() == original_processor.features.keys() - for key in new_processor.features: - assert new_processor.features[key].type == original_processor.features[key].type - assert new_processor.features[key].shape == original_processor.features[key].shape + assert ( + new_processor.transform_features(features).keys() + == original_processor.transform_features(features).keys() + ) + for key in new_processor.transform_features(features): + assert ( + new_processor.transform_features(features)[key].type + == original_processor.transform_features(features)[key].type + ) + assert ( + new_processor.transform_features(features)[key].shape + == original_processor.transform_features(features)[key].shape + ) assert new_processor.norm_map == original_processor.norm_map + + +# Identity normalization tests +def test_identity_normalization_observations(): + """Test that IDENTITY mode skips normalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_normalization_actions(): + """Test that IDENTITY mode skips normalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([1.0, -0.5]) + transition = create_transition(action=action) + + normalized_transition = normalizer(transition) + + # Action should remain unchanged + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + + +def test_identity_unnormalization_observations(): + """Test that IDENTITY mode skips unnormalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + } + transition = create_transition(observation=observation) + + unnormalized_transition = unnormalizer(transition) + unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + + # State should be unnormalized (MIN_MAX) + # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 + # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 + expected_state = torch.tensor([0.0, -1.0]) + assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + + +def test_identity_unnormalization_actions(): + """Test that IDENTITY mode skips unnormalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} + + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([0.5, -0.8]) # Normalized values + transition = create_transition(action=action) + + unnormalized_transition = unnormalizer(transition) + + # Action should remain unchanged + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_with_missing_stats(): + """Test that IDENTITY mode works even when stats are missing.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = {} # No stats provided + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Both should work without errors and return unchanged data + normalized_transition = normalizer(transition) + unnormalized_transition = unnormalizer(transition) + + assert torch.allclose( + normalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + assert torch.allclose( + unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_mixed_with_other_modes(): + """Test IDENTITY mode mixed with other normalization modes.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + action = torch.tensor([0.5, 0.0]) + transition = create_transition(observation=observation, action=action) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + normalized_action = normalized_transition[TransitionKey.ACTION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + # Action should be normalized (MIN_MAX) to [-1, 1] + # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 + # 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0 + expected_action = torch.tensor([0.5, 0.0]) + assert torch.allclose(normalized_action, expected_action) + + +def test_identity_defaults_when_not_in_norm_map(): + """Test that IDENTITY is used as default when feature type not in norm_map.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + # VISUAL not specified, should default to IDENTITY + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (defaults to IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (explicitly MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_roundtrip(): + """Test that IDENTITY normalization and unnormalization are true inverses.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_action = torch.tensor([0.5, -0.2]) + original_transition = create_transition(observation=original_observation, action=original_action) + + # Normalize then unnormalize + normalized = normalizer(original_transition) + roundtrip = unnormalizer(normalized) + + # Should be identical to original + assert torch.allclose( + roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] + ) + assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) + + +def test_identity_config_serialization(): + """Test that IDENTITY mode is properly saved and loaded in config.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Get config + config = normalizer.get_config() + + # Check that IDENTITY is properly serialized + assert config["norm_map"]["VISUAL"] == "IDENTITY" + assert config["norm_map"]["ACTION"] == "MEAN_STD" + + # Create new processor from config (simulating load) + new_normalizer = NormalizerProcessorStep( + features=config["features"], + norm_map=config["norm_map"], + stats=stats, + eps=config["eps"], + ) + + # Test that both work the same way + observation = {"observation.image": torch.tensor([0.7])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + result1 = normalizer(transition) + result2 = new_normalizer(transition) + + # Results should be identical + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) + + +# def test_unsupported_normalization_mode_error(): +# """Test that unsupported normalization modes raise appropriate errors.""" +# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} + +# # Create an invalid norm_map (this would never happen in practice, but tests error handling) +# from enum import Enum + +# class InvalidMode(str, Enum): +# INVALID = "INVALID" + +# # We can't actually pass an invalid enum to the processor due to type checking, +# # but we can test the error by manipulating the norm_map after creation +# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} +# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} + +# normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + +# # Manually inject an invalid mode to test error handling +# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" + +# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# transition = create_transition(observation=observation) + +# with pytest.raises(ValueError, match="Unsupported normalization mode"): +# normalizer(transition) + + +def test_hotswap_stats_basic_functionality(): + """Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps.""" + # Create initial stats + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # Create new stats for hotswapping + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + # Create features and norm_map + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create processors + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + identity = IdentityProcessorStep() + + # Create robot processor + robot_processor = DataProcessorPipeline(steps=[normalizer, unnormalizer, identity]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that normalizer and unnormalizer have new stats + assert new_processor.steps[0].stats == new_stats + assert new_processor.steps[1].stats == new_stats + + # Check that tensor stats are updated correctly + expected_tensor_stats = to_tensor(new_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + torch.testing.assert_close( + new_processor.steps[0]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + torch.testing.assert_close( + new_processor.steps[1]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_hotswap_stats_deep_copy(): + """Test that hotswap_stats creates a deep copy and doesn't modify the original processor.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + original_processor = DataProcessorPipeline(steps=[normalizer]) + + # Store reference to original stats + original_stats_reference = original_processor.steps[0].stats + original_tensor_stats_reference = original_processor.steps[0]._tensor_stats + + # Hotswap stats + new_processor = hotswap_stats(original_processor, new_stats) + + # Original processor should be unchanged + assert original_processor.steps[0].stats is original_stats_reference + assert original_processor.steps[0]._tensor_stats is original_tensor_stats_reference + assert original_processor.steps[0].stats == initial_stats + + # New processor should have new stats + assert new_processor.steps[0].stats == new_stats + assert new_processor.steps[0].stats is not original_stats_reference + + # Processors should be different objects + assert new_processor is not original_processor + assert new_processor.steps[0] is not original_processor.steps[0] + + +def test_hotswap_stats_only_affects_normalizer_steps(): + """Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps.""" + stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + # Create mixed steps + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + identity = IdentityProcessorStep() + + robot_processor = DataProcessorPipeline(steps=[normalizer, identity, unnormalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that only normalizer and unnormalizer steps are affected + assert new_processor.steps[0].stats == new_stats # normalizer + assert new_processor.steps[2].stats == new_stats # unnormalizer + + # Identity processor should remain unchanged (and it doesn't have stats attribute) + assert not hasattr(new_processor.steps[1], "stats") + + +def test_hotswap_stats_empty_stats(): + """Test hotswap_stats with empty stats dictionary.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + empty_stats = {} + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap with empty stats + new_processor = hotswap_stats(robot_processor, empty_stats) + + # Should update to empty stats + assert new_processor.steps[0].stats == empty_stats + assert new_processor.steps[0]._tensor_stats == {} + + +def test_hotswap_stats_no_normalizer_steps(): + """Test hotswap_stats with a processor that has no normalizer/unnormalizer steps.""" + stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + # Create processor with only identity steps + robot_processor = DataProcessorPipeline(steps=[IdentityProcessorStep(), IdentityProcessorStep()]) + + # Hotswap stats - should work without error + new_processor = hotswap_stats(robot_processor, stats) + + # Should return a different object (deep copy) + assert new_processor is not robot_processor + + # Steps should be deep copied but unchanged + assert len(new_processor.steps) == len(robot_processor.steps) + for i, step in enumerate(new_processor.steps): + assert step is not robot_processor.steps[i] # Different objects + assert isinstance(step, type(robot_processor.steps[i])) # Same type + + +def test_hotswap_stats_preserves_other_attributes(): + """Test that hotswap_stats preserves other processor attributes like features and norm_map.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalize_observation_keys = {"observation.image"} + eps = 1e-6 + + normalizer = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=initial_stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, + ) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that other attributes are preserved + new_normalizer = new_processor.steps[0] + assert new_normalizer.features == features + assert new_normalizer.norm_map == norm_map + assert new_normalizer.normalize_observation_keys == normalize_observation_keys + assert new_normalizer.eps == eps + + # But stats should be updated + assert new_normalizer.stats == new_stats + + +def test_hotswap_stats_multiple_normalizer_types(): + """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + # Create multiple normalizers and unnormalizers + normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer1 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer2 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + + robot_processor = DataProcessorPipeline(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # All normalizer/unnormalizer steps should be updated + for step in new_processor.steps: + assert step.stats == new_stats + + # Check tensor stats conversion + expected_tensor_stats = to_tensor(new_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + torch.testing.assert_close( + step._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_hotswap_stats_with_different_data_types(): + """Test hotswap_stats with various data types in stats.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + # New stats with different data types (int, float, list, tuple) + new_stats = { + "observation.image": { + "mean": [0.3, 0.4, 0.5], # list + "std": (0.1, 0.2, 0.3), # tuple + "min": 0, # int + "max": 1.0, # float + }, + "action": { + "mean": np.array([0.1, 0.2]), # numpy array + "std": torch.tensor([0.5, 0.6]), # torch tensor + }, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that stats are updated + assert new_processor.steps[0].stats == new_stats + + # Check that tensor conversion worked correctly + tensor_stats = new_processor.steps[0]._tensor_stats + assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor) + assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["action"]["std"], torch.Tensor) + + # Check values + torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5])) + torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3])) + torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0)) + torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) + + +def test_hotswap_stats_functional_test(): + """Test that hotswapped processor actually works functionally.""" + # Create test data + observation = { + "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), + } + action = torch.tensor([0.5, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Initial stats + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # New stats + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, + "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create original processor + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + original_processor = DataProcessorPipeline( + steps=[normalizer], to_transition=identity_transition, to_output=identity_transition + ) + + # Process with original stats + original_result = original_processor(transition) + + # Hotswap stats + new_processor = hotswap_stats(original_processor, new_stats) + + # Process with new stats + new_result = new_processor(transition) + + # Results should be different since normalization changed + assert not torch.allclose( + original_result["observation"]["observation.image"], + new_result["observation"]["observation.image"], + rtol=1e-3, + atol=1e-3, + ) + assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3) + + # Verify that the new processor is actually using the new stats by checking internal state + assert new_processor.steps[0].stats == new_stats + assert torch.allclose( + new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2]) + ) + assert torch.allclose( + new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2]) + ) + assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) + assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) + + # Test that normalization actually happens (output should not equal input) + assert not torch.allclose( + new_result["observation"]["observation.image"], observation["observation.image"] + ) + assert not torch.allclose(new_result["action"], action) + + +def test_zero_std_uses_eps(): + """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([0.5])} # equals mean + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + + +def test_min_equals_max_maps_to_minus_one(): + """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} + stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([2.0])} + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + + +def test_action_normalized_despite_normalize_observation_keys(): + """Action normalization is independent of normalize_observation_keys filter for observations.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} + ) + + transition = create_transition( + observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + ) + out = normalizer(transition) + # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 + assert torch.allclose(out[TransitionKey.ACTION], torch.tensor([1.0, 1.0])) + + +def test_unnormalize_observations_mean_std_and_min_max(): + features = { + "observation.ms": PolicyFeature(FeatureType.STATE, (2,)), + "observation.mm": PolicyFeature(FeatureType.STATE, (2,)), + } + # Build two processors: one mean/std and one min/max + unnorm_ms = UnnormalizerProcessorStep( + features={"observation.ms": features["observation.ms"]}, + norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD}, + stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}, + ) + unnorm_mm = UnnormalizerProcessorStep( + features={"observation.mm": features["observation.mm"]}, + norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX}, + stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}}, + ) + + tr = create_transition( + observation={ + "observation.ms": torch.tensor([0.0, 0.0]), # → mean + "observation.mm": torch.tensor([0.0, 0.0]), # → mid-point + } + ) + out_ms = unnorm_ms(tr)[TransitionKey.OBSERVATION]["observation.ms"] + out_mm = unnorm_mm(tr)[TransitionKey.OBSERVATION]["observation.mm"] + assert torch.allclose(out_ms, torch.tensor([1.0, -1.0])) + assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2] + + +def test_unknown_observation_keys_ignored(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + tr = create_transition(observation=obs) + out = normalizer(tr) + + # Unknown key should pass through unchanged and not be tracked + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"]) + + +def test_batched_action_normalization(): + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1] + out = normalizer(create_transition(action=actions))[TransitionKey.ACTION] + expected = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) + assert torch.allclose(out, expected) + + +def test_complementary_data_preservation(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + comp = {"existing": 123} + tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + out = normalizer(tr) + new_comp = out[TransitionKey.COMPLEMENTARY_DATA] + assert new_comp["existing"] == 123 + + +def test_roundtrip_normalize_unnormalize_non_identity(): + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} + stats = { + "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, + } + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Add a time dimension in action for broadcasting check (B,T,D) + obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] + + tr = create_transition(observation=obs, action=act) + out = unnormalizer(normalizer(tr)) + + assert torch.allclose( + out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 + ) + assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) + + +def test_dtype_adaptation_bfloat16_input_float32_normalizer(): + """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = { + "observation.state": { + "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + } + } + + # Create normalizer configured with float32 dtype + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, dtype=torch.float32 + ) + + # Verify initial configuration + assert normalizer.dtype == torch.float32 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.float32 + + # Create bfloat16 input tensor + observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} + transition = create_transition(observation=observation) + + # Process the transition + result = normalizer(transition) + + # Verify that: + # 1. Stats were automatically adapted to bfloat16 + assert normalizer.dtype == torch.bfloat16 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.bfloat16 + + # 2. Output is in bfloat16 + output_tensor = result[TransitionKey.OBSERVATION]["observation.state"] + assert output_tensor.dtype == torch.bfloat16 + + # 3. Normalization was applied correctly (mean should be close to original - mean) / std + expected = ( + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16) + - torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.bfloat16) + ) / torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.bfloat16) + assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision + + +def test_stats_override_preservation_in_load_state_dict(): + """ + Test that explicitly provided stats are preserved during load_state_dict. + + This tests the fix for the bug where stats provided via overrides were + being overwritten when load_state_dict was called. + """ + # Create original stats + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # Create override stats (what user wants to use) + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer with override stats (simulating from_pretrained with overrides) + override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats) + + # Verify that the override stats are initially set correctly + assert set(override_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + assert override_normalizer._stats_explicitly_provided is True + + # This is the critical test: load_state_dict should NOT overwrite the override stats + override_normalizer.load_state_dict(saved_state_dict) + + # After loading state_dict, stats should still be the override stats, not the original stats + # Check that loaded stats match override stats + assert set(override_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + # Compare individual arrays to avoid numpy array comparison ambiguity + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + override_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Verify that _tensor_stats are also correctly set to match the override stats + expected_tensor_stats = to_tensor(override_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor): + torch.testing.assert_close( + override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_stats_without_override_loads_normally(): + """ + Test that when stats are not explicitly provided (normal case), + load_state_dict works as before. + """ + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer without stats (simulating normal from_pretrained) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + + # Verify that stats are not explicitly provided + assert new_normalizer._stats_explicitly_provided is False + + # Load state dict - this should work normally and load the saved stats + new_normalizer.load_state_dict(saved_state_dict) + + # Stats should now match the original stats (normal behavior) + # Check that all keys and values match + assert set(new_normalizer.stats.keys()) == set(original_stats.keys()) + for key in original_stats: + assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys()) + for stat_name in original_stats[key]: + np.testing.assert_allclose( + new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6 + ) + + +def test_stats_explicit_provided_flag_detection(): + """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + # Test 1: Explicitly provided stats (non-empty dict) + stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + assert normalizer1._stats_explicitly_provided is True + + # Test 2: Empty stats dict + normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + assert normalizer2._stats_explicitly_provided is False + + # Test 3: None stats + normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None) + assert normalizer3._stats_explicitly_provided is False + + # Test 4: Stats not provided (defaults to None) + normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map) + assert normalizer4._stats_explicitly_provided is False + + +def test_pipeline_from_pretrained_with_stats_overrides(): + """ + Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides. + + This is an integration test that verifies the fix works in the real scenario + where users provide stat overrides when loading a pipeline. + """ + import tempfile + + # Create test data + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + # Create and save a pipeline with the original stats + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + identity = IdentityProcessorStep() + original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline") + + with tempfile.TemporaryDirectory() as temp_dir: + # Save the pipeline + original_pipeline.save_pretrained(temp_dir) + + # Load the pipeline with stat overrides + overrides = {"normalizer_processor": {"stats": override_stats}} + + loaded_pipeline = DataProcessorPipeline.from_pretrained( + temp_dir, config_filename="test_pipeline.json", overrides=overrides + ) + + # The critical test: the loaded pipeline should use override stats, not original stats + loaded_normalizer = loaded_pipeline.steps[0] + assert isinstance(loaded_normalizer, NormalizerProcessorStep) + + # Check that loaded stats match override stats + assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + + # Verify stats don't match original stats + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Test that the override stats are actually used in processing + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process with override pipeline + override_result = loaded_pipeline(transition) + + # Create a reference pipeline with override stats for comparison + reference_normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=override_stats + ) + reference_pipeline = DataProcessorPipeline( + steps=[reference_normalizer, identity], + to_transition=identity_transition, + to_output=identity_transition, + ) + _ = reference_pipeline(transition) + + # The critical part was verified above: loaded_normalizer.stats == override_stats + # This confirms that override stats are preserved during load_state_dict. + # Let's just verify the pipeline processes data successfully. + assert "action" in override_result + assert isinstance(override_result["action"], torch.Tensor) + + +def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): + """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" + from lerobot.processor import DeviceProcessorStep + + features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} + + # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32) + device_processor = DeviceProcessorStep(device=str(auto_select_torch_device()), float_dtype="bfloat16") + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, dtype=torch.float32 + ) + + # Verify initial normalizer configuration + assert normalizer.dtype == torch.float32 + + # Create CPU input + observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} + transition = create_transition(observation=observation) + + # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA + processed_1 = device_processor(transition) + intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"] + assert intermediate_tensor.dtype == torch.bfloat16 + assert intermediate_tensor.device.type == str(auto_select_torch_device()) + + # Step 2: NormalizerProcessor receives bfloat16 input and adapts + final_result = normalizer(processed_1) + final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"] + + # Verify final output is bfloat16 (automatic adaptation worked) + assert final_tensor.dtype == torch.bfloat16 + assert final_tensor.device.type == str(auto_select_torch_device()) + + # Verify normalizer adapted its internal state + assert normalizer.dtype == torch.bfloat16 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.bfloat16 + assert stat_tensor.device.type == str(auto_select_torch_device()) + + +def test_stats_reconstruction_after_load_state_dict(): + """ + Test that stats dict is properly reconstructed from _tensor_stats after loading. + + This test ensures the bug where stats became empty after loading is fixed. + The bug occurred when: + 1. Only _tensor_stats were saved via state_dict() + 2. stats field became empty {} after loading + 3. Calling to() method or hotswap_stats would fail because they depend on self.stats + """ + + # Create normalizer with stats + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 2.0]), + }, + } + + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Save state dict (simulating save/load) + state_dict = original_normalizer.state_dict() + + # Create new normalizer with empty stats (simulating load) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + + # Before fix: this would cause stats to remain empty + new_normalizer.load_state_dict(state_dict) + + # Verify that stats dict is properly reconstructed from _tensor_stats + assert new_normalizer.stats is not None + assert new_normalizer.stats != {} + + # Check that all expected keys are present + assert "observation.image" in new_normalizer.stats + assert "observation.state" in new_normalizer.stats + assert "action" in new_normalizer.stats + + # Check that values are correct (converted back from tensors) + np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2]) + np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0]) + np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0]) + np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) + np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) + + # Test that methods that depend on self.stats work correctly after loading + # This would fail before the bug fix because self.stats was empty + + # Test 1: to() method should work without crashing + try: + new_normalizer.to(device="cpu", dtype=torch.float32) + # If we reach here, the bug is fixed + except (KeyError, AttributeError) as e: + pytest.fail(f"to() method failed after loading state_dict: {e}") + + # Test 2: hotswap_stats should work + new_stats = { + "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, + "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, + "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, + } + + pipeline = DataProcessorPipeline([new_normalizer]) + try: + new_pipeline = hotswap_stats(pipeline, new_stats) + # If we reach here, hotswap_stats worked correctly + assert new_pipeline.steps[0].stats == new_stats + except (KeyError, AttributeError) as e: + pytest.fail(f"hotswap_stats failed after loading state_dict: {e}") + + # Test 3: The normalizer should work functionally the same as the original + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + original_result = original_normalizer(transition) + new_result = new_normalizer(transition) + + # Results should be identical (within floating point precision) + torch.testing.assert_close( + original_result[TransitionKey.OBSERVATION]["observation.image"], + new_result[TransitionKey.OBSERVATION]["observation.image"], + ) + torch.testing.assert_close( + original_result[TransitionKey.OBSERVATION]["observation.state"], + new_result[TransitionKey.OBSERVATION]["observation.state"], + ) + torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION]) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index e48b6bc0..57f32482 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -18,31 +18,16 @@ import numpy as np import pytest import torch -from lerobot.configs.types import FeatureType +from lerobot.configs.types import FeatureType, PipelineFeatureType from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor import VanillaObservationProcessor -from lerobot.processor.pipeline import TransitionKey +from lerobot.processor import TransitionKey, VanillaObservationProcessorStep +from lerobot.processor.converters import create_transition from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - def test_process_single_image(): """Test processing a single image.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create a mock image (H, W, C) format, uint8 image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -68,7 +53,7 @@ def test_process_single_image(): def test_process_image_dict(): """Test processing multiple images in a dictionary.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create mock images image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -91,7 +76,7 @@ def test_process_image_dict(): def test_process_batched_image(): """Test processing already batched images.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create a batched image (B, H, W, C) image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) @@ -108,7 +93,7 @@ def test_process_batched_image(): def test_invalid_image_format(): """Test error handling for invalid image formats.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test wrong channel order (channels first) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) @@ -121,7 +106,7 @@ def test_invalid_image_format(): def test_invalid_image_dtype(): """Test error handling for invalid image dtype.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test wrong dtype image = np.random.rand(64, 64, 3).astype(np.float32) @@ -134,7 +119,7 @@ def test_invalid_image_dtype(): def test_no_pixels_in_observation(): """Test processor when no pixels are in observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -149,9 +134,9 @@ def test_no_pixels_in_observation(): def test_none_observation(): """Test processor with None observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() - transition = create_transition() + transition = create_transition(observation={}) result = processor(transition) assert result == transition @@ -159,7 +144,7 @@ def test_none_observation(): def test_serialization_methods(): """Test serialization methods.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test get_config config = processor.get_config() @@ -178,7 +163,7 @@ def test_serialization_methods(): def test_process_environment_state(): """Test processing environment_state.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) observation = {"environment_state": env_state} @@ -199,7 +184,7 @@ def test_process_environment_state(): def test_process_agent_pos(): """Test processing agent_pos.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -220,7 +205,7 @@ def test_process_agent_pos(): def test_process_batched_states(): """Test processing already batched states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) @@ -238,7 +223,7 @@ def test_process_batched_states(): def test_process_both_states(): """Test processing both environment_state and agent_pos.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([1.0, 2.0], dtype=np.float32) agent_pos = np.array([0.5, -0.5], dtype=np.float32) @@ -263,7 +248,7 @@ def test_process_both_states(): def test_no_states_in_observation(): """Test processor when no states are in observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -277,7 +262,7 @@ def test_no_states_in_observation(): def test_complete_observation_processing(): """Test processing a complete observation with both images and states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create mock data image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -314,7 +299,7 @@ def test_complete_observation_processing(): def test_image_only_processing(): """Test processing observation with only images.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) observation = {"pixels": image} @@ -329,7 +314,7 @@ def test_image_only_processing(): def test_state_only_processing(): """Test processing observation with only states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() agent_pos = np.array([1.0, 2.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -344,7 +329,7 @@ def test_state_only_processing(): def test_empty_observation(): """Test processing empty observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {} transition = create_transition(observation=observation) @@ -360,7 +345,7 @@ def test_equivalent_to_original_function(): # Import the original function for comparison from lerobot.envs.utils import preprocess_observation - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create test data similar to what the original function expects image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -387,7 +372,7 @@ def test_equivalent_with_image_dict(): """Test equivalence with dictionary of images.""" from lerobot.envs.utils import preprocess_observation - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create test data with multiple cameras image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -410,77 +395,133 @@ def test_equivalent_with_image_dict(): torch.testing.assert_close(original_result[key], processor_result[key]) -def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] - assert "pixels" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_IMAGE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] + == features[PipelineFeatureType.OBSERVATION]["pixels"] + ) + assert "pixels" not in out[PipelineFeatureType.OBSERVATION] + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_observation_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] - assert "observation.pixels" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_IMAGE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] + == features[PipelineFeatureType.OBSERVATION]["observation.pixels"] + ) + assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION] + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (7,)), + PipelineFeatureType.OBSERVATION: { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] - assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] - assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] - assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out - assert out["keep"] == features["keep"] + assert ( + f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] + == features[PipelineFeatureType.OBSERVATION]["pixels.front"] + ) + assert ( + f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"] + == features[PipelineFeatureType.OBSERVATION]["pixels.wrist"] + ) + assert ( + f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"] + == features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"] + ) + assert ( + "pixels.front" not in out[PipelineFeatureType.OBSERVATION] + and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION] + and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION] + ) + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_state_processor_features_environment_and_agent_pos(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), - "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] + == features[PipelineFeatureType.OBSERVATION]["environment_state"] + ) + assert ( + OBS_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_STATE] + == features[PipelineFeatureType.OBSERVATION]["agent_pos"] + ) + assert ( + "environment_state" not in out[PipelineFeatureType.OBSERVATION] + and "agent_pos" not in out[PipelineFeatureType.OBSERVATION] + ) + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): - proc = VanillaObservationProcessor() +def test_state_processor_features_prefixed_inputs(policy_feature_factory): + proc = VanillaObservationProcessorStep() features = { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), - "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + PipelineFeatureType.OBSERVATION: { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + }, } - out = proc.feature_contract(features.copy()) + out = proc.transform_features(features.copy()) - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out + assert ( + OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] + == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] + ) + assert ( + OBS_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_STATE] + == features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"] + ) + assert ( + "environment_state" not in out[PipelineFeatureType.OBSERVATION] + and "agent_pos" not in out[PipelineFeatureType.OBSERVATION] + ) assert_contract_is_typed(out) diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py new file mode 100644 index 00000000..c481cb18 --- /dev/null +++ b/tests/processor/test_pi0_processor.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for PI0 policy processor.""" + +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + EnvTransition, + NormalizerProcessorStep, + ProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + +def create_default_config(): + """Create a default PI0 configuration for testing.""" + config = PI0Config() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + config.tokenizer_max_length = 128 + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_pi0_processor_basic(): + """Test basic creation of PI0 processor.""" + config = create_default_config() + stats = create_default_stats() + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 6 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor) + # Step 3 would be TokenizerProcessorStep but it's mocked + assert isinstance(preprocessor.steps[4], DeviceProcessorStep) + assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_pi0_newline_processor_single_task(): + """Test Pi0NewLineProcessor with single task string.""" + processor = Pi0NewLineProcessor() + + # Test with task that doesn't have newline + transition = create_transition(complementary_data={"task": "test task"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + # Test with task that already has newline + transition = create_transition(complementary_data={"task": "test task\n"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + +def test_pi0_newline_processor_list_of_tasks(): + """Test Pi0NewLineProcessor with list of task strings.""" + processor = Pi0NewLineProcessor() + + # Test with list of tasks + tasks = ["task1", "task2\n", "task3"] + transition = create_transition(complementary_data={"task": tasks}) + result = processor(transition) + expected = ["task1\n", "task2\n", "task3\n"] + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected + + +def test_pi0_newline_processor_empty_transition(): + """Test Pi0NewLineProcessor with empty transition.""" + processor = Pi0NewLineProcessor() + + # Test with no complementary_data + transition = create_transition() + result = processor(transition) + assert result == transition + + # Test with complementary_data but no task + transition = create_transition(complementary_data={"other": "data"}) + result = processor(transition) + assert result == transition + + # Test with None task + transition = create_transition(complementary_data={"task": None}) + result = processor(transition) + assert result == transition + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_cuda(): + """Test PI0 processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action, complementary_data={"task": "test task"}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_accelerate_scenario(): + """Test PI0 processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 10).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_pi0_processor_multi_gpu(): + """Test PI0 processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 10).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_pi0_processor_without_stats(): + """Test PI0 processor creation without dataset statistics.""" + config = create_default_config() + + # Mock the tokenizer processor + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + +def test_pi0_newline_processor_state_dict(): + """Test Pi0NewLineProcessor state dict methods.""" + processor = Pi0NewLineProcessor() + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should do nothing) + processor.load_state_dict({}) + + # Test reset (should do nothing) + processor.reset() + + # Test get_config + config = processor.get_config() + assert config == {} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + stats = create_default_stats() + config.device = "cuda" + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, _ = make_pi0_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5) + normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10 + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6 + transition = create_transition( + observation, action, complementary_data={"task": "test bfloat16 adaptation"} + ) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 5665d5a7..0d17fed0 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -25,29 +25,21 @@ import pytest import torch import torch.nn as nn -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor -from lerobot.processor.pipeline import TransitionKey +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.processor import ( + DataProcessorPipeline, + EnvTransition, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info if info is not None else {}, - TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, - } - - @dataclass -class MockStep: +class MockStep(ProcessorStep): """Mock pipeline step for testing - demonstrates best practices. This example shows the proper separation: @@ -90,13 +82,15 @@ class MockStep: def reset(self) -> None: self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @dataclass -class MockStepWithoutOptionalMethods: +class MockStepWithoutOptionalMethods(ProcessorStep): """Mock step that only implements the required __call__ method.""" multiplier: float = 2.0 @@ -112,13 +106,15 @@ class MockStepWithoutOptionalMethods: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @dataclass -class MockStepWithTensorState: +class MockStepWithTensorState(ProcessorStep): """Mock step demonstrating mixed JSON attributes and tensor state.""" name: str = "tensor_step" @@ -168,14 +164,16 @@ class MockStepWithTensorState: self.running_mean.zero_() self.running_count.zero_() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features def test_empty_pipeline(): """Test pipeline with no steps.""" - pipeline = RobotProcessor() + pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition) transition = create_transition() result = pipeline(transition) @@ -187,7 +185,7 @@ def test_empty_pipeline(): def test_single_step_pipeline(): """Test pipeline with a single step.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step], to_transition=identity_transition, to_output=identity_transition) transition = create_transition() result = pipeline(transition) @@ -204,7 +202,9 @@ def test_multiple_steps_pipeline(): """Test pipeline with multiple steps.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline( + [step1, step2], to_transition=identity_transition, to_output=identity_transition + ) transition = create_transition() result = pipeline(transition) @@ -216,7 +216,7 @@ def test_multiple_steps_pipeline(): def test_invalid_transition_format(): """Test pipeline with invalid transition format.""" - pipeline = RobotProcessor([MockStep()]) + pipeline = DataProcessorPipeline([MockStep()]) # Test with wrong type (tuple instead of dict) with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): @@ -231,7 +231,7 @@ def test_step_through(): """Test step_through method with dict input.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) transition = create_transition() @@ -252,7 +252,7 @@ def test_step_through_with_dict(): """Test step_through method with dict input.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) batch = { "observation.image": None, @@ -291,7 +291,7 @@ def test_step_through_with_dict(): def test_step_through_no_hooks(): """Test that step_through doesn't execute hooks.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) hook_calls = [] @@ -326,7 +326,7 @@ def test_indexing(): """Test pipeline indexing.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) # Test integer indexing assert pipeline[0] is step1 @@ -334,7 +334,7 @@ def test_indexing(): # Test slice indexing sub_pipeline = pipeline[0:1] - assert isinstance(sub_pipeline, RobotProcessor) + assert isinstance(sub_pipeline, DataProcessorPipeline) assert len(sub_pipeline) == 1 assert sub_pipeline[0] is step1 @@ -342,7 +342,7 @@ def test_indexing(): def test_hooks(): """Test before/after step hooks.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) before_calls = [] after_calls = [] @@ -366,7 +366,7 @@ def test_hooks(): def test_unregister_hooks(): """Test unregistering hooks from the pipeline.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Test before_step_hook before_calls = [] @@ -405,7 +405,7 @@ def test_unregister_hooks(): def test_unregister_nonexistent_hook(): """Test error handling when unregistering hooks that don't exist.""" - pipeline = RobotProcessor([MockStep()]) + pipeline = DataProcessorPipeline([MockStep()]) def some_hook(idx: int, transition: EnvTransition): pass @@ -423,7 +423,7 @@ def test_unregister_nonexistent_hook(): def test_multiple_hooks_and_selective_unregister(): """Test registering multiple hooks and selectively unregistering them.""" - pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) + pipeline = DataProcessorPipeline([MockStep("step1"), MockStep("step2")]) calls_1 = [] calls_2 = [] @@ -469,7 +469,7 @@ def test_multiple_hooks_and_selective_unregister(): def test_hook_execution_order_documentation(): """Test and document that hooks are executed sequentially in registration order.""" - pipeline = RobotProcessor([MockStep("step")]) + pipeline = DataProcessorPipeline([MockStep("step")]) execution_order = [] @@ -521,7 +521,7 @@ def test_save_and_load_pretrained(): step1.counter = 5 step2.counter = 10 - pipeline = RobotProcessor([step1, step2], name="TestPipeline") + pipeline = DataProcessorPipeline([step1, step2], name="TestPipeline") with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline @@ -543,7 +543,7 @@ def test_save_and_load_pretrained(): assert config["steps"][1]["config"]["counter"] == 10 # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="testpipeline.json") assert loaded_pipeline.name == "TestPipeline" assert len(loaded_pipeline) == 2 @@ -556,7 +556,9 @@ def test_save_and_load_pretrained(): def test_step_without_optional_methods(): """Test pipeline with steps that don't implement optional methods.""" step = MockStepWithoutOptionalMethods(multiplier=3.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline( + [step], to_transition=identity_transition, to_output=identity_transition + ) # Identity for EnvTransition input/output transition = create_transition(reward=2.0) result = pipeline(transition) @@ -569,14 +571,16 @@ def test_step_without_optional_methods(): # Save/load should work even without optional methods with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) assert len(loaded_pipeline) == 1 def test_mixed_json_and_tensor_state(): """Test step with both JSON attributes and tensor state.""" step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Process some transitions with rewards for i in range(10): @@ -592,13 +596,15 @@ def test_mixed_json_and_tensor_state(): pipeline.save_pretrained(tmp_dir) # Check that both config and state files were created - config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" - state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors" + config_path = Path(tmp_dir) / "dataprocessorpipeline.json" # Default name is "RobotProcessor" + state_path = Path(tmp_dir) / "dataprocessorpipeline_step_0.safetensors" assert config_path.exists() assert state_path.exists() # Load and verify - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) loaded_step = loaded_pipeline.steps[0] # Check JSON attributes were restored @@ -611,7 +617,7 @@ def test_mixed_json_and_tensor_state(): assert torch.allclose(loaded_step.running_mean, step.running_mean) -class MockModuleStep(nn.Module): +class MockModuleStep(ProcessorStep, nn.Module): """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" def __init__(self, input_dim: int = 10, hidden_dim: int = 5): @@ -651,23 +657,25 @@ class MockModuleStep(nn.Module): def state_dict(self) -> dict[str, torch.Tensor]: """Override to return all module parameters and buffers.""" # Get the module's state dict (includes all parameters and buffers) - return super().state_dict() + return nn.Module.state_dict(self) def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Override to load all module parameters and buffers.""" # Use the module's load_state_dict - super().load_state_dict(state) + nn.Module.load_state_dict(self, state) def reset(self) -> None: self.running_mean.zero_() self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features -class MockNonModuleStepWithState: +class MockNonModuleStepWithState(ProcessorStep): """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. This tests the state_dict/load_state_dict path for regular classes. @@ -744,14 +752,16 @@ class MockNonModuleStepWithState: self.step_count.zero_() self.history.clear() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features # Tests for overrides functionality @dataclass -class MockStepWithNonSerializableParam: +class MockStepWithNonSerializableParam(ProcessorStep): """Mock step that requires a non-serializable parameter.""" def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): @@ -799,14 +809,16 @@ class MockStepWithNonSerializableParam: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @ProcessorStepRegistry.register("registered_mock_step") @dataclass -class RegisteredMockStep: +class RegisteredMockStep(ProcessorStep): """Mock step registered in the registry.""" value: int = 42 @@ -838,8 +850,10 @@ class RegisteredMockStep: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @@ -859,7 +873,7 @@ def test_from_pretrained_with_overrides(): env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) registered_step = RegisteredMockStep(value=100, device="cpu") - pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + pipeline = DataProcessorPipeline([env_step, registered_step], name="TestOverrides") with tempfile.TemporaryDirectory() as tmp_dir: # Save the pipeline @@ -877,7 +891,13 @@ def test_from_pretrained_with_overrides(): "registered_mock_step": {"device": "cuda", "value": 200}, } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="testoverrides.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Verify the pipeline was loaded correctly assert len(loaded_pipeline) == 2 @@ -903,7 +923,7 @@ def test_from_pretrained_with_partial_overrides(): step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -913,7 +933,13 @@ def test_from_pretrained_with_partial_overrides(): # The current implementation applies overrides to ALL steps with the same class name # Both steps will get the override - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) transition = create_transition(reward=1.0) result = loaded_pipeline(transition) @@ -927,7 +953,7 @@ def test_from_pretrained_with_partial_overrides(): def test_from_pretrained_invalid_override_key(): """Test that invalid override keys raise KeyError.""" step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -936,13 +962,15 @@ def test_from_pretrained_invalid_override_key(): overrides = {"NonExistentStep": {"param": "value"}} with pytest.raises(KeyError, match="Override keys.*do not match any step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_from_pretrained_multiple_invalid_override_keys(): """Test that multiple invalid override keys are reported.""" step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -951,7 +979,9 @@ def test_from_pretrained_multiple_invalid_override_keys(): overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) error_msg = str(exc_info.value) assert "NonExistentStep1" in error_msg @@ -962,7 +992,7 @@ def test_from_pretrained_multiple_invalid_override_keys(): def test_from_pretrained_registered_step_override(): """Test overriding registered steps using registry names.""" registered_step = RegisteredMockStep(value=50, device="cpu") - pipeline = RobotProcessor([registered_step]) + pipeline = DataProcessorPipeline([registered_step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -970,7 +1000,13 @@ def test_from_pretrained_registered_step_override(): # Override using registry name overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Test that overrides were applied transition = create_transition() @@ -986,7 +1022,7 @@ def test_from_pretrained_mixed_registered_and_unregistered(): unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) registered_step = RegisteredMockStep(value=10, device="cpu") - pipeline = RobotProcessor([unregistered_step, registered_step]) + pipeline = DataProcessorPipeline([unregistered_step, registered_step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -998,7 +1034,13 @@ def test_from_pretrained_mixed_registered_and_unregistered(): "registered_mock_step": {"value": 777}, } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Test both steps transition = create_transition(reward=2.0) @@ -1013,13 +1055,18 @@ def test_from_pretrained_mixed_registered_and_unregistered(): def test_from_pretrained_no_overrides(): """Test that from_pretrained works without overrides (backward compatibility).""" step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load without overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) assert len(loaded_pipeline) == 1 @@ -1033,13 +1080,19 @@ def test_from_pretrained_no_overrides(): def test_from_pretrained_empty_overrides(): """Test that from_pretrained works with empty overrides dict.""" step = MockStepWithNonSerializableParam(multiplier=2.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load with empty overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={}, + to_transition=identity_transition, + to_output=identity_transition, + ) assert len(loaded_pipeline) == 1 @@ -1053,7 +1106,7 @@ def test_from_pretrained_empty_overrides(): def test_from_pretrained_override_instantiation_error(): """Test that instantiation errors with overrides are properly reported.""" step = MockStepWithNonSerializableParam(multiplier=1.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1066,13 +1119,15 @@ def test_from_pretrained_override_instantiation_error(): } with pytest.raises(ValueError, match="Failed to instantiate processor step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_from_pretrained_with_state_and_overrides(): """Test that overrides work correctly with steps that have tensor state.""" step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Process some data to create state for i in range(10): @@ -1090,7 +1145,9 @@ def test_from_pretrained_with_state_and_overrides(): } } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) loaded_step = loaded_pipeline.steps[0] # Check that config overrides were applied @@ -1109,7 +1166,7 @@ def test_from_pretrained_override_error_messages(): """Test that error messages for override failures are helpful.""" step1 = MockStepWithNonSerializableParam(name="step1") step2 = RegisteredMockStep() - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1118,7 +1175,9 @@ def test_from_pretrained_override_error_messages(): overrides = {"WrongStepName": {"param": "value"}} with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) error_msg = str(exc_info.value) assert "WrongStepName" in error_msg @@ -1129,20 +1188,20 @@ def test_from_pretrained_override_error_messages(): def test_repr_empty_processor(): """Test __repr__ with empty processor.""" - pipeline = RobotProcessor() + pipeline = DataProcessorPipeline() repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=0: [])" assert repr_str == expected def test_repr_single_step(): """Test __repr__ with single step.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=1: [MockStep])" assert repr_str == expected @@ -1150,18 +1209,18 @@ def test_repr_multiple_steps_under_limit(): """Test __repr__ with 2-3 steps (all shown).""" step1 = MockStep("step1") step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=2: [MockStep, MockStepWithoutOptionalMethods])" assert repr_str == expected # Test with 3 steps (boundary case) step3 = MockStepWithTensorState() - pipeline = RobotProcessor([step1, step2, step3]) + pipeline = DataProcessorPipeline([step1, step2, step3]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" assert repr_str == expected @@ -1173,30 +1232,30 @@ def test_repr_many_steps_truncated(): step4 = MockModuleStep() step5 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4, step5]) + pipeline = DataProcessorPipeline([step1, step2, step3, step4, step5]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" assert repr_str == expected def test_repr_with_custom_name(): """Test __repr__ with custom processor name.""" step = MockStep("test_step") - pipeline = RobotProcessor([step], name="CustomProcessor") + pipeline = DataProcessorPipeline([step], name="CustomProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='CustomProcessor', steps=1: [MockStep])" assert repr_str == expected def test_repr_with_seed(): """Test __repr__ with seed parameter.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=1: [MockStep])" assert repr_str == expected @@ -1204,20 +1263,22 @@ def test_repr_with_custom_name_and_seed(): """Test __repr__ with both custom name and seed.""" step1 = MockStep("step1") step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2], name="MyProcessor") + pipeline = DataProcessorPipeline([step1, step2], name="MyProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + expected = ( + "DataProcessorPipeline(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + ) assert repr_str == expected def test_repr_without_seed(): """Test __repr__ when seed is explicitly None (should not show seed).""" step = MockStep("test_step") - pipeline = RobotProcessor([step], name="TestProcessor") + pipeline = DataProcessorPipeline([step], name="TestProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='TestProcessor', steps=1: [MockStep])" assert repr_str == expected @@ -1228,10 +1289,10 @@ def test_repr_various_step_types(): step3 = MockModuleStep() step4 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") + pipeline = DataProcessorPipeline([step1, step2, step3, step4], name="MixedSteps") repr_str = repr(pipeline) - expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" assert repr_str == expected @@ -1242,10 +1303,10 @@ def test_repr_edge_case_long_names(): step3 = MockStepWithTensorState() step4 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") + pipeline = DataProcessorPipeline([step1, step2, step3, step4], name="LongNames") repr_str = repr(pipeline) - expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" assert repr_str == expected @@ -1253,7 +1314,7 @@ def test_repr_edge_case_long_names(): def test_save_with_custom_config_filename(): """Test saving processor with custom config filename.""" step = MockStep("test") - pipeline = RobotProcessor([step], name="TestProcessor") + pipeline = DataProcessorPipeline([step], name="TestProcessor") with tempfile.TemporaryDirectory() as tmp_dir: # Save with custom filename @@ -1269,16 +1330,18 @@ def test_save_with_custom_config_filename(): assert config["name"] == "TestProcessor" # Load with specific filename - loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="my_custom_config.json") assert loaded.name == "TestProcessor" def test_multiple_processors_same_directory(): """Test saving multiple processors to the same directory with different config files.""" # Create different processors - preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") + preprocessor = DataProcessorPipeline([MockStep("pre1"), MockStep("pre2")], name="preprocessor") - postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") + postprocessor = DataProcessorPipeline( + [MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor" + ) with tempfile.TemporaryDirectory() as tmp_dir: # Save both to same directory @@ -1290,8 +1353,8 @@ def test_multiple_processors_same_directory(): assert (Path(tmp_dir) / "postprocessor.json").exists() # Load them back - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + loaded_pre = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="postprocessor.json") assert loaded_pre.name == "preprocessor" assert loaded_post.name == "postprocessor" @@ -1299,31 +1362,34 @@ def test_multiple_processors_same_directory(): assert len(loaded_post) == 1 -def test_auto_detect_single_config(): - """Test automatic config detection when there's only one JSON file.""" +def test_explicit_config_filename_loading(): + """Test explicit config filename loading (no more auto-detection).""" step = MockStepWithTensorState() - pipeline = RobotProcessor([step], name="SingleConfig") + pipeline = DataProcessorPipeline([step], name="SingleConfig") with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) - # Load without specifying config_filename - loaded = RobotProcessor.from_pretrained(tmp_dir) + # Load with explicit config_filename (now required) + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="singleconfig.json") assert loaded.name == "SingleConfig" -def test_error_multiple_configs_no_filename(): - """Test error when multiple configs exist and no filename specified.""" - proc1 = RobotProcessor([MockStep()], name="processor1") - proc2 = RobotProcessor([MockStep()], name="processor2") +def test_explicit_config_selection_with_multiple_configs(): + """Test explicit config selection when multiple configs exist.""" + proc1 = DataProcessorPipeline([MockStep()], name="processor1") + proc2 = DataProcessorPipeline([MockStep()], name="processor2") with tempfile.TemporaryDirectory() as tmp_dir: proc1.save_pretrained(tmp_dir) proc2.save_pretrained(tmp_dir) - # Should raise error - with pytest.raises(ValueError, match="Multiple .json files found"): - RobotProcessor.from_pretrained(tmp_dir) + # Can load specific configs explicitly + loaded1 = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor1.json") + loaded2 = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor2.json") + + assert loaded1.name == "processor1" + assert loaded2.name == "processor2" def test_state_file_naming_with_indices(): @@ -1333,7 +1399,7 @@ def test_state_file_naming_with_indices(): step2 = MockStepWithTensorState(name="norm2", window_size=10) step3 = MockModuleStep(input_dim=5) - pipeline = RobotProcessor([step1, step2, step3]) + pipeline = DataProcessorPipeline([step1, step2, step3]) # Process some data to create state for i in range(5): @@ -1349,9 +1415,9 @@ def test_state_file_naming_with_indices(): # Files should be named with pipeline name prefix and indices expected_names = [ - "robotprocessor_step_0.safetensors", - "robotprocessor_step_1.safetensors", - "robotprocessor_step_2.safetensors", + "dataprocessorpipeline_step_0.safetensors", + "dataprocessorpipeline_step_1.safetensors", + "dataprocessorpipeline_step_2.safetensors", ] actual_names = [f.name for f in state_files] assert actual_names == expected_names @@ -1363,7 +1429,7 @@ def test_state_file_naming_with_registry(): # Register a test step @ProcessorStepRegistry.register("test_stateful_step") @dataclass - class TestStatefulStep: + class TestStatefulStep(ProcessorStep): value: int = 0 def __init__(self, value: int = 0): @@ -1382,15 +1448,17 @@ def test_state_file_naming_with_registry(): def load_state_dict(self, state): self.state_tensor = state["state_tensor"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: # Create pipeline with registered steps step1 = TestStatefulStep(1) step2 = TestStatefulStep(2) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1401,8 +1469,8 @@ def test_state_file_naming_with_registry(): # Should include pipeline name, index and registry name expected_names = [ - "robotprocessor_step_0_test_stateful_step.safetensors", - "robotprocessor_step_1_test_stateful_step.safetensors", + "dataprocessorpipeline_step_0_test_stateful_step.safetensors", + "dataprocessorpipeline_step_1_test_stateful_step.safetensors", ] actual_names = [f.name for f in state_files] assert actual_names == expected_names @@ -1418,7 +1486,7 @@ def test_override_with_nested_config(): @ProcessorStepRegistry.register("complex_config_step") @dataclass - class ComplexConfigStep: + class ComplexConfigStep(ProcessorStep): name: str = "complex" simple_param: int = 42 nested_config: dict = None @@ -1439,21 +1507,26 @@ def test_override_with_nested_config(): def get_config(self): return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = ComplexConfigStep() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load with nested override - loaded = RobotProcessor.from_pretrained( + loaded = DataProcessorPipeline.from_pretrained( tmp_dir, + config_filename="dataprocessorpipeline.json", overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}, + to_transition=identity_transition, + to_output=identity_transition, ) # Test that override worked @@ -1467,14 +1540,15 @@ def test_override_with_nested_config(): def test_override_preserves_defaults(): """Test that overrides only affect specified parameters.""" step = MockStepWithNonSerializableParam(name="test", multiplier=2.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override only one parameter - loaded = RobotProcessor.from_pretrained( + loaded = DataProcessorPipeline.from_pretrained( tmp_dir, + config_filename="dataprocessorpipeline.json", overrides={ "MockStepWithNonSerializableParam": { "multiplier": 5.0 # Only override multiplier @@ -1491,7 +1565,7 @@ def test_override_preserves_defaults(): def test_override_type_validation(): """Test that type errors in overrides are caught properly.""" step = MockStepWithTensorState(learning_rate=0.01) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1504,7 +1578,9 @@ def test_override_type_validation(): } with pytest.raises(ValueError, match="Failed to instantiate"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_override_with_callables(): @@ -1512,7 +1588,7 @@ def test_override_with_callables(): @ProcessorStepRegistry.register("callable_step") @dataclass - class CallableStep: + class CallableStep(ProcessorStep): name: str = "callable_step" transform_fn: Any = None @@ -1531,13 +1607,15 @@ def test_override_with_callables(): def get_config(self): return {"name": self.name} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = CallableStep() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1551,8 +1629,12 @@ def test_override_with_callables(): return x # Load with callable override - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"callable_step": {"transform_fn": double_values}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"callable_step": {"transform_fn": double_values}}, + to_transition=identity_transition, + to_output=identity_transition, ) # Test it works @@ -1567,14 +1649,16 @@ def test_override_multiple_same_class_warning(): """Test behavior when multiple steps of same class exist.""" step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override affects all instances of the class - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}, ) # Both steps get the same override @@ -1589,7 +1673,7 @@ def test_override_multiple_same_class_warning(): def test_config_filename_special_characters(): """Test config filenames with special characters are sanitized.""" # Processor name with special characters - pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars") + pipeline = DataProcessorPipeline([MockStep()], name="My/Processor\\With:Special*Chars") with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1607,10 +1691,10 @@ def test_state_file_naming_with_multiple_processors(): """Test that state files are properly prefixed with pipeline names to avoid conflicts.""" # Create two processors with state step1 = MockStepWithTensorState(name="norm", window_size=5) - preprocessor = RobotProcessor([step1], name="PreProcessor") + preprocessor = DataProcessorPipeline([step1], name="PreProcessor") step2 = MockStepWithTensorState(name="norm", window_size=10) - postprocessor = RobotProcessor([step2], name="PostProcessor") + postprocessor = DataProcessorPipeline([step2], name="PostProcessor") # Process some data to create state for i in range(3): @@ -1630,8 +1714,8 @@ def test_state_file_naming_with_multiple_processors(): assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists() # Load both back and verify they work correctly - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + loaded_pre = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="postprocessor.json") assert loaded_pre.name == "PreProcessor" assert loaded_post.name == "PostProcessor" @@ -1644,7 +1728,7 @@ def test_override_with_device_strings(): @ProcessorStepRegistry.register("device_aware_step") @dataclass - class DeviceAwareStep: + class DeviceAwareStep(ProcessorStep): device: str = "cpu" def __init__(self, device: str = "cpu"): @@ -1663,21 +1747,25 @@ def test_override_with_device_strings(): def load_state_dict(self, state): self.buffer = state["buffer"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = DeviceAwareStep(device="cpu") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override device if torch.cuda.is_available(): - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"device_aware_step": {"device": "cuda:0"}}, ) loaded_step = loaded.steps[0] @@ -1691,20 +1779,27 @@ def test_override_with_device_strings(): def test_from_pretrained_nonexistent_path(): """Test error handling when loading from non-existent sources.""" - from huggingface_hub.errors import HfHubHTTPError, HFValidationError + from huggingface_hub.errors import HfHubHTTPError - # Test with an invalid repo ID (too many slashes) - caught by HF validation - with pytest.raises(HFValidationError): - RobotProcessor.from_pretrained("/path/that/does/not/exist") + # Test with an invalid local path - should raise FileNotFoundError + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained("/path/that/does/not/exist", config_filename="processor.json") - # Test with a non-existent but valid Hub repo format + # Test with a path that doesn't exist as a directory + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained("user/repo/extra/path", config_filename="processor.json") + + # Test with a non-existent Hub repo with pytest.raises((FileNotFoundError, HfHubHTTPError)): - RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo") + DataProcessorPipeline.from_pretrained( + "nonexistent-user/nonexistent-repo", config_filename="processor.json" + ) # Test with a local directory that exists but has no config files with tempfile.TemporaryDirectory() as tmp_dir: - with pytest.raises(FileNotFoundError, match="No .json configuration files found"): - RobotProcessor.from_pretrained(tmp_dir) + # Since the directory exists but has no config, it will raise FileNotFoundError + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json") def test_save_load_with_custom_converter_functions(): @@ -1733,13 +1828,15 @@ def test_save_load_with_custom_converter_functions(): } # Create processor with custom converters - pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output) + pipeline = DataProcessorPipeline( + [MockStep()], to_transition=custom_to_transition, to_output=custom_to_output + ) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load - should use default converters - loaded = RobotProcessor.from_pretrained(tmp_dir) + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="dataprocessorpipeline.json") # Verify it uses default converters by checking with standard batch format batch = { @@ -1753,35 +1850,39 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) - assert "observation.image" in result # Standard format preserved + # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict + assert "observation.image" in result class NonCompliantStep: - """Intentionally non-compliant: missing feature_contract.""" + """Intentionally non-compliant: missing features.""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition -def test_construction_rejects_step_without_feature_contract(): - with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): - RobotProcessor([NonCompliantStep()]) - - -class NonCallableStep: +class NonCallableStep(ProcessorStep): """Intentionally non-compliant: missing __call__.""" - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features def test_construction_rejects_step_without_call(): - with pytest.raises(TypeError, match=r"must define __call__"): - RobotProcessor([NonCallableStep()]) + """Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep.""" + with pytest.raises( + TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_" + ): + DataProcessorPipeline([NonCallableStep()]) + + with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"): + DataProcessorPipeline([NonCompliantStep()]) @dataclass -class FeatureContractAddStep: +class FeatureContractAddStep(ProcessorStep): """Adds a PolicyFeature""" key: str = "a" @@ -1790,39 +1891,47 @@ class FeatureContractAddStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.value + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION][self.key] = self.value return features @dataclass -class FeatureContractMutateStep: +class FeatureContractMutateStep(ProcessorStep): """Mutates a PolicyFeature""" key: str = "a" - fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 + fn: Callable[[PolicyFeature | None], PolicyFeature] = identity_transition # noqa: E731 def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.fn(features.get(self.key)) + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION][self.key] = self.fn( + features[PipelineFeatureType.OBSERVATION].get(self.key) + ) return features @dataclass -class FeatureContractBadReturnStep: +class FeatureContractBadReturnStep(ProcessorStep): """Returns a non-dict""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return ["not-a-dict"] @dataclass -class FeatureContractRemoveStep: +class FeatureContractRemoveStep(ProcessorStep): """Removes a PolicyFeature""" key: str @@ -1830,32 +1939,39 @@ class FeatureContractRemoveStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features.pop(self.key, None) + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION].pop(self.key, None) return features -def test_feature_contract_orders_and_merges(policy_feature_factory): - p = RobotProcessor( +def test_features_orders_and_merges(policy_feature_factory): + p = DataProcessorPipeline( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), ] ) - out = p.feature_contract({}) - - assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) - assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) + out = p.transform_features({PipelineFeatureType.OBSERVATION: {}}) + assert out[PipelineFeatureType.OBSERVATION]["a"].type == FeatureType.STATE and out[ + PipelineFeatureType.OBSERVATION + ]["a"].shape == (3,) + assert out[PipelineFeatureType.OBSERVATION]["b"].type == FeatureType.ENV and out[ + PipelineFeatureType.OBSERVATION + ]["b"].shape == (2,) assert_contract_is_typed(out) -def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): +def test_features_respects_initial_without_mutation(policy_feature_factory): initial = { - "seed": policy_feature_factory(FeatureType.STATE, (7,)), - "nested": policy_feature_factory(FeatureType.ENV, (0,)), + PipelineFeatureType.OBSERVATION: { + "seed": policy_feature_factory(FeatureType.STATE, (7,)), + "nested": policy_feature_factory(FeatureType.ENV, (0,)), + } } - p = RobotProcessor( + p = DataProcessorPipeline( [ FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), FeatureContractMutateStep( @@ -1863,57 +1979,224 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto ), ] ) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) - assert out["seed"].shape == (8,) - assert out["nested"].shape == (5,) + assert out[PipelineFeatureType.OBSERVATION]["seed"].shape == (8,) + assert out[PipelineFeatureType.OBSERVATION]["nested"].shape == (5,) # Initial dict must be preserved - assert initial["seed"].shape == (7,) - assert initial["nested"].shape == (0,) + assert initial[PipelineFeatureType.OBSERVATION]["seed"].shape == (7,) + assert initial[PipelineFeatureType.OBSERVATION]["nested"].shape == (0,) assert_contract_is_typed(out) -def test_feature_contract_type_error_on_bad_step(): - p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) - with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): - _ = p.feature_contract({}) - - -def test_feature_contract_execution_order_tracking(): - class Track: +def test_features_execution_order_tracking(): + class Track(ProcessorStep): def __init__(self, label): self.label = label def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: code = {"A": 1, "B": 2, "C": 3}[self.label] - pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) - features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) + pf = features[PipelineFeatureType.OBSERVATION].get( + "order", PolicyFeature(type=FeatureType.ENV, shape=()) + ) + features[PipelineFeatureType.OBSERVATION]["order"] = PolicyFeature( + type=pf.type, shape=pf.shape + (code,) + ) return features - out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) - assert out["order"].shape == (1, 2, 3) + out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features( + initial_features={PipelineFeatureType.OBSERVATION: {}} + ) + assert out[PipelineFeatureType.OBSERVATION]["order"].shape == (1, 2, 3) -def test_feature_contract_remove_key(policy_feature_factory): - p = RobotProcessor( +def test_features_remove_key(policy_feature_factory): + p = DataProcessorPipeline( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractRemoveStep("a"), ] ) - out = p.feature_contract({}) - assert "a" not in out + out = p.transform_features({PipelineFeatureType.OBSERVATION: {}}) + assert "a" not in out[PipelineFeatureType.OBSERVATION] -def test_feature_contract_remove_from_initial(policy_feature_factory): +def test_features_remove_from_initial(policy_feature_factory): initial = { - "keep": policy_feature_factory(FeatureType.STATE, (1,)), - "drop": policy_feature_factory(FeatureType.STATE, (1,)), + PipelineFeatureType.OBSERVATION: { + "keep": policy_feature_factory(FeatureType.STATE, (1,)), + "drop": policy_feature_factory(FeatureType.STATE, (1,)), + }, } - p = RobotProcessor([FeatureContractRemoveStep("drop")]) - out = p.feature_contract(initial_features=initial) - assert "drop" not in out and out["keep"] == initial["keep"] + p = DataProcessorPipeline([FeatureContractRemoveStep("drop")]) + out = p.transform_features(initial_features=initial) + assert ( + "drop" not in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION]["keep"] == initial[PipelineFeatureType.OBSERVATION]["keep"] + ) + + +@dataclass +class AddActionEEAndJointFeatures(ProcessorStep): + """Adds both EE and JOINT action features.""" + + def __call__(self, tr): + return tr + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # EE features + features[PipelineFeatureType.ACTION]["action.ee.x"] = float + features[PipelineFeatureType.ACTION]["action.ee.y"] = float + # JOINT features + features[PipelineFeatureType.ACTION]["action.j1.pos"] = float + features[PipelineFeatureType.ACTION]["action.j2.pos"] = float + return features + + +@dataclass +class AddObservationStateFeatures(ProcessorStep): + """Adds state features (and optionally an image spec to test precedence).""" + + add_front_image: bool = False + front_image_shape: tuple = (240, 320, 3) + + def __call__(self, tr): + return tr + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # State features (mix EE and a joint state) + features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float + features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float + if self.add_front_image: + features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape + return features + + +def test_aggregate_joint_action_only(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures()]) + initial = {PipelineFeatureType.OBSERVATION: {"front": (480, 640, 3)}, PipelineFeatureType.ACTION: {}} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.j1.pos", "action.j2.pos"], + ) + + # Expect only "action" with joint names + assert "action" in out and "observation.state" not in out + assert out["action"]["dtype"] == "float32" + assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} + assert out["action"]["shape"] == (len(out["action"]["names"]),) + + +def test_aggregate_ee_action_and_observation_with_videos(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures(), AddObservationStateFeatures()]) + initial = {"front": (480, 640, 3), "side": (720, 1280, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, + use_videos=True, + patterns=["action.ee", "observation.state"], + ) + + # Action should pack only EE names + assert "action" in out + assert set(out["action"]["names"]) == {"ee.x", "ee.y"} + assert out["action"]["dtype"] == "float32" + + # Observation state should pack both ee.x and j1.pos as a vector + assert "observation.state" in out + assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} + assert out["observation.state"]["dtype"] == "float32" + + # Cameras from initial_features appear as videos + for cam in ("front", "side"): + key = f"observation.images.{cam}" + assert key in out + assert out[key]["dtype"] == "video" + assert out[key]["shape"] == initial[cam] + assert out[key]["names"] == ["height", "width", "channels"] + + +def test_aggregate_both_action_types(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures()]) + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}}, + use_videos=True, + patterns=["action.ee", "action.j1", "action.j2.pos"], + ) + + assert "action" in out + expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} + assert set(out["action"]["names"]) == expected + assert out["action"]["shape"] == (len(expected),) + + +def test_aggregate_images_when_use_videos_false(): + rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, + use_videos=False, # expect "image" dtype + patterns=None, + ) + + key = "observation.images.back" + key_front = "observation.images.front" + assert key not in out + assert key_front not in out + + +def test_aggregate_images_when_use_videos_true(): + rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, + use_videos=True, + patterns=None, + ) + + key = "observation.images.front" + key_back = "observation.images.back" + assert key in out + assert key_back in out + assert out[key]["dtype"] == "video" + assert out[key_back]["dtype"] == "video" + assert out[key_back]["shape"] == initial["back"] + + +def test_initial_camera_not_overridden_by_step_image(): + # Step explicitly sets a different front image shape; initial has another shape. + # aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams). + rp = DataProcessorPipeline( + [AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))] + ) + initial = {"front": (480, 640, 3)} # should NOT override the step-provided (240, 320, 3) + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, + use_videos=True, + patterns=["observation.images.front"], + ) + + key = "observation.images.front" + assert key in out + assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_pipeline_from_pretrained_helpers.py b/tests/processor/test_pipeline_from_pretrained_helpers.py new file mode 100644 index 00000000..89d45cba --- /dev/null +++ b/tests/processor/test_pipeline_from_pretrained_helpers.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for DataProcessorPipeline.from_pretrained helper methods. + +These tests focus on the individual private methods that were extracted from +the main from_pretrained method to improve modularity and testability. +""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError + +# Simplified Config Loading Tests + + +def test_load_config_directory(): + """Test loading config from directory.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a config file + config_file = tmp_path / "processor.json" + test_config = {"name": "TestProcessor", "steps": []} + config_file.write_text(json.dumps(test_config)) + + # Load from directory + loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + assert loaded_config == test_config + assert base_path == tmp_path + + +def test_load_config_single_file(): + """Test loading config from a single file path.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a config file + config_file = tmp_path / "processor.json" + test_config = {"name": "TestProcessor", "steps": []} + config_file.write_text(json.dumps(test_config)) + + # Load using file path directly + loaded_config, base_path = DataProcessorPipeline._load_config( + str(config_file), "any_filename_ignored", {} + ) + + assert loaded_config == test_config + assert base_path == tmp_path + + +def test_load_config_directory_file_not_found(): + """Test directory loading when config file doesn't exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Directory exists but no processor.json + with pytest.raises(FileNotFoundError, match="not found in directory"): + DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + +def test_load_config_directory_with_migration_detection(): + """Test that missing config triggers migration detection.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create old-style config to trigger migration + (tmp_path / "config.json").write_text(json.dumps({"type": "act"})) + + # Try to load processor.json (doesn't exist), should trigger migration + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + +def test_load_config_nonexistent_path_tries_hub(): + """Test that nonexistent paths try Hub (simplified logic).""" + # This path doesn't exist locally, should try Hub + with pytest.raises(FileNotFoundError, match="on the HuggingFace Hub"): + DataProcessorPipeline._load_config("nonexistent/path", "processor.json", {}) + + +# Config Validation Tests + + +def test_validate_loaded_config_valid_config(): + """Test validation with valid processor config.""" + valid_config = {"name": "TestProcessor", "steps": []} + + # Should not raise any exception + DataProcessorPipeline._validate_loaded_config("any-path", valid_config, "processor.json") + + +def test_validate_loaded_config_invalid_config(): + """Test validation with invalid processor config.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create non-processor config to trigger migration + (tmp_path / "config.json").write_text(json.dumps({"type": "act"})) + + invalid_config = {"type": "act", "hidden_dim": 256} + + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline._validate_loaded_config(str(tmp_path), invalid_config, "config.json") + + +def test_validate_loaded_config_invalid_config_no_migration(): + """Test validation with invalid config when no migration is detected.""" + # Non-directory path (Hub repo) - no migration detection + invalid_config = {"type": "act", "hidden_dim": 256} + + with pytest.raises(ValueError, match="not a valid processor configuration"): + DataProcessorPipeline._validate_loaded_config("user/repo", invalid_config, "config.json") + + +# Step Class Resolution Tests + + +def test_resolve_step_class_registry_name(): + """Test resolution using registry name.""" + from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry + + # Register a test step + @ProcessorStepRegistry.register("test_step") + class TestStep(ProcessorStep): + def __call__(self, transition): + return transition + + def transform_features(self, features): + return features + + try: + step_entry = {"registry_name": "test_step"} + step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry) + + assert step_class is TestStep + assert step_key == "test_step" + finally: + ProcessorStepRegistry.unregister("test_step") + + +def test_resolve_step_class_registry_name_not_found(): + """Test resolution with non-existent registry name.""" + step_entry = {"registry_name": "nonexistent_step"} + + with pytest.raises(ImportError, match="Failed to load processor step from registry"): + DataProcessorPipeline._resolve_step_class(step_entry) + + +def test_resolve_step_class_import_path(): + """Test resolution using full import path.""" + # Use a valid existing class (this should work) + step_entry = {"class": "lerobot.processor.pipeline.ProcessorStep"} + + # This should succeed - ProcessorStep can be imported, just not instantiated + step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry) + + from lerobot.processor.pipeline import ProcessorStep + + assert step_class is ProcessorStep + assert step_key == "ProcessorStep" + + +def test_resolve_step_class_invalid_import_path(): + """Test resolution with invalid import path.""" + step_entry = {"class": "nonexistent.module.ClassName"} + + with pytest.raises(ImportError, match="Failed to load processor step"): + DataProcessorPipeline._resolve_step_class(step_entry) + + +# Override Validation Tests + + +def test_validate_overrides_used_all_used(): + """Test validation when all overrides are used.""" + # Empty set means all overrides were used + remaining_overrides = set() + config = {"steps": [{"class": "SomeStep"}]} + + # Should not raise + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + +def test_validate_overrides_used_some_unused(): + """Test validation when some overrides are unused.""" + remaining_overrides = {"NonExistentStep", "AnotherMissingStep"} + config = { + "steps": [ + {"registry_name": "normalize_step"}, + {"class": "some.module.TransformStep"}, + ] + } + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + +def test_validate_overrides_used_helpful_error_message(): + """Test that error message includes available step keys.""" + remaining_overrides = {"WrongStep"} + config = { + "steps": [ + {"registry_name": "correct_step"}, + {"class": "module.path.CorrectClass"}, + ] + } + + with pytest.raises(KeyError) as exc_info: + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + error_msg = str(exc_info.value) + assert "Available step keys" in error_msg + assert "correct_step" in error_msg + assert "CorrectClass" in error_msg + + +# Integration Tests for Simplified Logic + + +def test_simplified_three_way_loading(): + """Test that the simplified 3-way loading logic works correctly.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Test 1: Directory loading + config_file = tmp_path / "processor.json" + test_config = {"name": "DirectoryTest", "steps": []} + config_file.write_text(json.dumps(test_config)) + + loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + assert loaded_config["name"] == "DirectoryTest" + assert base_path == tmp_path + + # Test 2: Single file loading + loaded_config, base_path = DataProcessorPipeline._load_config( + str(config_file), "ignored_filename", {} + ) + assert loaded_config["name"] == "DirectoryTest" + assert base_path == tmp_path diff --git a/tests/processor/test_policy_robot_bridge.py b/tests/processor/test_policy_robot_bridge.py new file mode 100644 index 00000000..f3bbd9a7 --- /dev/null +++ b/tests/processor/test_policy_robot_bridge.py @@ -0,0 +1,525 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType +from lerobot.processor import ( + DataProcessorPipeline, + PolicyActionToRobotActionProcessorStep, + ProcessorStepRegistry, + RobotActionToPolicyActionProcessorStep, +) +from lerobot.processor.converters import identity_transition +from tests.conftest import assert_contract_is_typed + + +def test_robot_to_policy_basic_action_conversion(): + """Test basic robot action to policy action conversion.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + } + + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + assert policy_action.shape == (3,) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) + + +def test_robot_to_policy_action_conversion_preserves_order(): + """Test that motor names order is preserved in conversion.""" + motor_names = ["gripper", "arm", "wrist"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "arm.pos": 10.0, + "gripper.pos": 5.0, + "wrist.pos": 15.0, + } + + policy_action = processor.action(robot_action) + + expected = torch.tensor([5.0, 10.0, 15.0]) + torch.testing.assert_close(policy_action, expected) + + +def test_robot_to_policy_action_conversion_with_floats_and_tensors(): + """Test conversion with mixed float and tensor values.""" + motor_names = ["joint1", "joint2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": torch.tensor(1.5), + "joint2.pos": 2.5, # Regular float + } + + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + torch.testing.assert_close(policy_action, torch.tensor([1.5, 2.5])) + + +def test_robot_to_policy_action_length_mismatch_error(): + """Test error when robot action length doesn't match motor names.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + # Too few actions + robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0} + + with pytest.raises(ValueError, match="Action must have 3 elements, got 2"): + processor.action(robot_action) + + robot_action = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + "extra.pos": 4.0, + } + + with pytest.raises(ValueError, match="Action must have 3 elements, got 4"): + processor.action(robot_action) + + +def test_robot_to_policy_missing_motor_key_error(): + """Test error when robot action is missing expected motor keys.""" + motor_names = ["joint1", "joint2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": 1.0, + "wrong_key.pos": 2.0, + } + + with pytest.raises(KeyError): + processor.action(robot_action) + + +def test_robot_to_policy_transform_features(): + """Test feature transformation for robot to policy action processor.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + features = { + PipelineFeatureType.ACTION: { + "joint1.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "joint2.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "joint3.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "other_data": {"type": FeatureType.ENV, "shape": (1,)}, + } + } + + transformed = processor.transform_features(features) + + assert "action" in transformed[PipelineFeatureType.ACTION] + action_feature = transformed[PipelineFeatureType.ACTION]["action"] + assert action_feature.type == FeatureType.ACTION + assert action_feature.shape == (3,) + + assert "joint1.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint2.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint3.pos" in transformed[PipelineFeatureType.ACTION] + + assert "other_data" in transformed[PipelineFeatureType.ACTION] + + +def test_robot_to_policy_get_config(): + """Test configuration serialization.""" + motor_names = ["motor1", "motor2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + config = processor.get_config() + assert config == {"motor_names": motor_names} + + +def test_robot_to_policy_state_dict(): + """Test state dict operations.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["joint1"]) + + state = processor.state_dict() + assert state == {} + + processor.load_state_dict({}) + + +def test_robot_to_policy_single_motor(): + """Test with single motor.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["single_joint"]) + + robot_action = {"single_joint.pos": 42.0} + policy_action = processor.action(robot_action) + + assert policy_action.shape == (1,) + torch.testing.assert_close(policy_action, torch.tensor([42.0])) + + +def test_policy_to_robot_basic_action_conversion(): + """Test basic policy action to robot action conversion.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([1.0, 2.0, 3.0]) + robot_action = processor.action(policy_action) + + assert isinstance(robot_action, dict) + assert len(robot_action) == 3 + + expected = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + } + + for key, expected_value in expected.items(): + assert key in robot_action + actual_value = robot_action[key] + if isinstance(actual_value, torch.Tensor): + actual_value = actual_value.item() + assert actual_value == pytest.approx(expected_value) + + +def test_policy_to_robot_action_conversion_preserves_order(): + """Test that motor names order corresponds to tensor indices.""" + motor_names = ["gripper", "arm", "wrist"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([5.0, 10.0, 15.0]) + robot_action = processor.action(policy_action) + + assert robot_action["gripper.pos"] == pytest.approx(5.0) + assert robot_action["arm.pos"] == pytest.approx(10.0) + assert robot_action["wrist.pos"] == pytest.approx(15.0) + + +def test_policy_to_robot_action_conversion_with_numpy_input(): + """Test conversion with numpy array input.""" + import numpy as np + + motor_names = ["joint1", "joint2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = np.array([1.5, 2.5]) + robot_action = processor.action(policy_action) + + assert robot_action["joint1.pos"] == pytest.approx(1.5) + assert robot_action["joint2.pos"] == pytest.approx(2.5) + + +def test_policy_to_robot_action_length_mismatch_error(): + """Test error when policy action length doesn't match motor names.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([1.0, 2.0]) + + with pytest.raises(ValueError, match="Action must have 3 elements, got 2"): + processor.action(policy_action) + + policy_action = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + with pytest.raises(ValueError, match="Action must have 3 elements, got 4"): + processor.action(policy_action) + + +def test_policy_to_robot_transform_features(): + """Test feature transformation for policy to robot action processor.""" + motor_names = ["joint1", "joint2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + features = { + PipelineFeatureType.ACTION: { + "action": {"type": FeatureType.ACTION, "shape": (2,)}, + "other_data": {"type": FeatureType.ENV, "shape": (1,)}, + } + } + + transformed = processor.transform_features(features) + + assert "joint1.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint2.pos" in transformed[PipelineFeatureType.ACTION] + + for motor in motor_names: + motor_feature = transformed[PipelineFeatureType.ACTION][f"{motor}.pos"] + assert motor_feature.type == FeatureType.ACTION + assert motor_feature.shape == (1,) + + assert "action" in transformed[PipelineFeatureType.ACTION] + + assert "other_data" in transformed[PipelineFeatureType.ACTION] + + +def test_policy_to_robot_get_config(): + """Test configuration serialization.""" + motor_names = ["motor1", "motor2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + config = processor.get_config() + assert config == {"motor_names": motor_names} + + +def test_policy_to_robot_state_dict(): + """Test state dict operations.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["joint1"]) + + state = processor.state_dict() + assert state == {} + + processor.load_state_dict({}) + + +def test_policy_to_robot_single_motor(): + """Test with single motor.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["single_joint"]) + + policy_action = torch.tensor([42.0]) + robot_action = processor.action(policy_action) + + assert len(robot_action) == 1 + assert robot_action["single_joint.pos"] == pytest.approx(42.0) + + +def test_robot_to_policy_registry(): + """Test RobotActionToPolicyActionProcessorStep registry.""" + assert "robot_action_to_policy_action_processor" in ProcessorStepRegistry.list() + + retrieved_class = ProcessorStepRegistry.get("robot_action_to_policy_action_processor") + assert retrieved_class is RobotActionToPolicyActionProcessorStep + + instance = retrieved_class(motor_names=["test"]) + assert isinstance(instance, RobotActionToPolicyActionProcessorStep) + assert instance.motor_names == ["test"] + + +def test_policy_to_robot_registry(): + """Test PolicyActionToRobotActionProcessorStep registry.""" + assert "policy_action_to_robot_action_processor" in ProcessorStepRegistry.list() + + retrieved_class = ProcessorStepRegistry.get("policy_action_to_robot_action_processor") + assert retrieved_class is PolicyActionToRobotActionProcessorStep + + instance = retrieved_class(motor_names=["test"]) + assert isinstance(instance, PolicyActionToRobotActionProcessorStep) + assert instance.motor_names == ["test"] + + +def test_save_and_load_robot_to_policy(): + """Test saving and loading RobotActionToPolicyActionProcessorStep.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + pipeline = DataProcessorPipeline([processor], name="TestRobotToPolicy") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check config file exists + config_path = Path(tmp_dir) / "testrobottopolicy.json" + assert config_path.exists() + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + "testrobottopolicy.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert loaded_pipeline.name == "TestRobotToPolicy" + assert len(loaded_pipeline) == 1 + + # Check loaded processor + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RobotActionToPolicyActionProcessorStep) + assert loaded_processor.motor_names == motor_names + + # Test functionality after loading + robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0, "joint3.pos": 3.0} + policy_action = loaded_processor.action(robot_action) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) + + +def test_save_and_load_policy_to_robot(): + """Test saving and loading PolicyActionToRobotActionProcessorStep.""" + motor_names = ["motor_a", "motor_b"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + pipeline = DataProcessorPipeline([processor], name="TestPolicyToRobot") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + "testpolicytorobot.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, PolicyActionToRobotActionProcessorStep) + assert loaded_processor.motor_names == motor_names + + policy_action = torch.tensor([10.0, 20.0]) + robot_action = loaded_processor.action(policy_action) + assert robot_action["motor_a.pos"] == pytest.approx(10.0) + assert robot_action["motor_b.pos"] == pytest.approx(20.0) + + +# Integration and chaining tests + + +def test_round_trip_conversion(): + """Test that robot->policy->robot conversion preserves values.""" + motor_names = ["joint1", "joint2", "joint3"] + robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + original_robot_action = { + "joint1.pos": 1.5, + "joint2.pos": -2.3, + "joint3.pos": 0.7, + } + + policy_action = robot_to_policy.action(original_robot_action) + final_robot_action = policy_to_robot.action(policy_action) + + for key in original_robot_action: + original_val = original_robot_action[key] + final_val = final_robot_action[key] + if isinstance(final_val, torch.Tensor): + final_val = final_val.item() + assert final_val == pytest.approx(original_val, abs=1e-6) + + +def test_chained_processors_in_pipeline(): + """Test both processors chained in a pipeline.""" + motor_names = ["joint1", "joint2"] + robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + pipeline = DataProcessorPipeline( + [robot_to_policy, policy_to_robot], + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert len(pipeline.steps) == 2 + assert isinstance(pipeline.steps[0], RobotActionToPolicyActionProcessorStep) + assert isinstance(pipeline.steps[1], PolicyActionToRobotActionProcessorStep) + + +def test_robot_to_policy_features_contract(policy_feature_factory): + """Test feature transformation maintains proper typing contract.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["j1", "j2"]) + features = { + PipelineFeatureType.ACTION: { + "j1.pos": policy_feature_factory(FeatureType.ACTION, (1,)), + "j2.pos": policy_feature_factory(FeatureType.ACTION, (1,)), + "other": policy_feature_factory(FeatureType.ENV, (3,)), + } + } + + out = processor.transform_features(features.copy()) + + assert_contract_is_typed(out) + + assert "action" in out[PipelineFeatureType.ACTION] + action_feature = out[PipelineFeatureType.ACTION]["action"] + assert action_feature.type == FeatureType.ACTION + assert action_feature.shape == (2,) + + +def test_policy_to_robot_features_contract(policy_feature_factory): + """Test feature transformation maintains proper typing contract.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"]) + features = { + PipelineFeatureType.ACTION: { + "action": policy_feature_factory(FeatureType.ACTION, (3,)), + "other": policy_feature_factory(FeatureType.ENV, (1,)), + } + } + + out = processor.transform_features(features.copy()) + + assert_contract_is_typed(out) + + for motor in ["m1", "m2", "m3"]: + key = f"{motor}.pos" + assert key in out[PipelineFeatureType.ACTION] + motor_feature = out[PipelineFeatureType.ACTION][key] + assert motor_feature.type == FeatureType.ACTION + assert motor_feature.shape == (1,) + + +def test_empty_motor_names_list(): + """Test behavior with empty motor names list.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=[]) + + robot_action = {} + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + assert policy_action.shape == (0,) + + +def test_empty_motor_names_list_policy_to_robot(): + """Test PolicyActionToRobotActionProcessorStep with empty motor names.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=[]) + + policy_action = torch.tensor([]) + robot_action = processor.action(policy_action) + + assert isinstance(robot_action, dict) + assert len(robot_action) == 0 + + +def test_very_long_motor_names(): + """Test with many motor names.""" + motor_names = [f"joint_{i}" for i in range(100)] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = {f"joint_{i}.pos": float(i) for i in range(100)} + policy_action = processor.action(robot_action) + + assert policy_action.shape == (100,) + expected = torch.tensor([float(i) for i in range(100)]) + torch.testing.assert_close(policy_action, expected) + + +def test_special_characters_in_motor_names(): + """Test with special characters in motor names.""" + motor_names = ["motor-1", "motor_2", "motor.3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "motor-1.pos": 1.0, + "motor_2.pos": 2.0, + "motor.3.pos": 3.0, + } + + policy_action = processor.action(robot_action) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 229d57f9..5f2b4857 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -19,33 +19,25 @@ from pathlib import Path import numpy as np import torch -from lerobot.configs.types import FeatureType -from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from lerobot.configs.types import FeatureType, PipelineFeatureType +from lerobot.processor import ( + DataProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.rename_processor import rename_stats from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - def test_basic_renaming(): """Test basic key renaming functionality.""" rename_map = { "old_key1": "new_key1", "old_key2": "new_key2", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "old_key1": torch.tensor([1.0, 2.0]), @@ -73,7 +65,7 @@ def test_basic_renaming(): def test_empty_rename_map(): """Test processor with empty rename map (should pass through unchanged).""" - processor = RenameProcessor(rename_map={}) + processor = RenameObservationsProcessorStep(rename_map={}) observation = { "key1": torch.tensor([1.0]), @@ -92,9 +84,9 @@ def test_empty_rename_map(): def test_none_observation(): """Test processor with None observation.""" - processor = RenameProcessor(rename_map={"old": "new"}) + processor = RenameObservationsProcessorStep(rename_map={"old": "new"}) - transition = create_transition() + transition = create_transition(observation={}) result = processor(transition) # Should return transition unchanged @@ -107,7 +99,7 @@ def test_overlapping_rename(): "a": "b", "b": "c", # This creates a potential conflict } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "a": 1, @@ -132,7 +124,7 @@ def test_partial_rename(): "observation.state": "observation.proprio_state", "pixels": "observation.image", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "observation.state": torch.randn(10), @@ -162,15 +154,15 @@ def test_get_config(): "old1": "new1", "old2": "new2", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) config = processor.get_config() assert config == {"rename_map": rename_map} def test_state_dict(): - """Test state dict (should be empty for RenameProcessor).""" - processor = RenameProcessor(rename_map={"old": "new"}) + """Test state dict (should be empty for RenameProcessorStep).""" + processor = RenameObservationsProcessorStep(rename_map={"old": "new"}) state = processor.state_dict() assert state == {} @@ -185,9 +177,11 @@ def test_integration_with_robot_processor(): "agent_pos": "observation.state", "pixels": "observation.image", } - rename_processor = RenameProcessor(rename_map=rename_map) + rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) - pipeline = RobotProcessor([rename_processor]) + pipeline = DataProcessorPipeline( + [rename_processor], to_transition=identity_transition, to_output=identity_transition + ) observation = { "agent_pos": np.array([1.0, 2.0, 3.0]), @@ -219,30 +213,37 @@ def test_save_and_load_pretrained(): "old_state": "observation.state", "old_image": "observation.image", } - processor = RenameProcessor(rename_map=rename_map) - pipeline = RobotProcessor([processor], name="TestRenameProcessor") + processor = RenameObservationsProcessorStep(rename_map=rename_map) + pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline pipeline.save_pretrained(tmp_dir) # Check files were created - config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor" + config_path = ( + Path(tmp_dir) / "testrenameprocessorstep.json" + ) # Based on name="TestRenameProcessorStep" assert config_path.exists() - # No state files should be created for RenameProcessor + # No state files should be created for RenameProcessorStep state_files = list(Path(tmp_dir).glob("*.safetensors")) assert len(state_files) == 0 # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="testrenameprocessorstep.json", + to_transition=identity_transition, + to_output=identity_transition, + ) - assert loaded_pipeline.name == "TestRenameProcessor" + assert loaded_pipeline.name == "TestRenameProcessorStep" assert len(loaded_pipeline) == 1 # Check that loaded processor works correctly loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) + assert isinstance(loaded_processor, RenameObservationsProcessorStep) assert loaded_processor.rename_map == rename_map # Test functionality after loading @@ -259,24 +260,26 @@ def test_save_and_load_pretrained(): def test_registry_functionality(): - """Test that RenameProcessor is properly registered.""" + """Test that RenameProcessorStep is properly registered.""" # Check that it's registered - assert "rename_processor" in ProcessorStepRegistry.list() + assert "rename_observations_processor" in ProcessorStepRegistry.list() # Get from registry - retrieved_class = ProcessorStepRegistry.get("rename_processor") - assert retrieved_class is RenameProcessor + retrieved_class = ProcessorStepRegistry.get("rename_observations_processor") + assert retrieved_class is RenameObservationsProcessorStep # Create instance from registry instance = retrieved_class(rename_map={"old": "new"}) - assert isinstance(instance, RenameProcessor) + assert isinstance(instance, RenameObservationsProcessorStep) assert instance.rename_map == {"old": "new"} def test_registry_based_save_load(): """Test save/load using registry name instead of module path.""" - processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) - pipeline = RobotProcessor([processor]) + processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"}) + pipeline = DataProcessorPipeline( + [processor], to_transition=identity_transition, to_output=identity_transition + ) with tempfile.TemporaryDirectory() as tmp_dir: # Save and load @@ -285,24 +288,26 @@ def test_registry_based_save_load(): # Verify config uses registry name import json - with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor" + with open(Path(tmp_dir) / "dataprocessorpipeline.json") as f: # Default name is "RobotProcessor" config = json.load(f) assert "registry_name" in config["steps"][0] - assert config["steps"][0]["registry_name"] == "rename_processor" + assert config["steps"][0]["registry_name"] == "rename_observations_processor" assert "class" not in config["steps"][0] # Should use registry, not module path # Load should work - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) + assert isinstance(loaded_processor, RenameObservationsProcessorStep) assert loaded_processor.rename_map == {"key1": "renamed_key1"} def test_chained_rename_processors(): - """Test multiple RenameProcessors in a pipeline.""" + """Test multiple RenameProcessorSteps in a pipeline.""" # First processor: rename raw keys to intermediate format - processor1 = RenameProcessor( + processor1 = RenameObservationsProcessorStep( rename_map={ "pos": "agent_position", "img": "camera_image", @@ -310,14 +315,16 @@ def test_chained_rename_processors(): ) # Second processor: rename to final format - processor2 = RenameProcessor( + processor2 = RenameObservationsProcessorStep( rename_map={ "agent_position": "observation.state", "camera_image": "observation.image", } ) - pipeline = RobotProcessor([processor1, processor2]) + pipeline = DataProcessorPipeline( + [processor1, processor2], to_transition=identity_transition, to_output=identity_transition + ) observation = { "pos": np.array([1.0, 2.0]), @@ -353,7 +360,7 @@ def test_nested_observation_rename(): "observation.images.right": "observation.camera.right_view", "observation.proprio": "observation.proprioception", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "observation.images.left": torch.randn(3, 64, 64), @@ -383,7 +390,7 @@ def test_nested_observation_rename(): def test_value_types_preserved(): """Test that various value types are preserved during renaming.""" rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) tensor_value = torch.randn(3, 3) array_value = np.random.rand(2, 2) @@ -410,58 +417,87 @@ def test_value_types_preserved(): assert processed_obs["old_list"] == [1, 2, 3] -def test_feature_contract_basic_renaming(policy_feature_factory): - processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) +def test_features_basic_renaming(policy_feature_factory): + processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"}) features = { - "a": policy_feature_factory(FeatureType.STATE, (2,)), - "b": policy_feature_factory(FeatureType.ACTION, (3,)), - "c": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "a": policy_feature_factory(FeatureType.VISUAL, (2,)), + "b": policy_feature_factory(FeatureType.VISUAL, (3,)), + "c": policy_feature_factory(FeatureType.VISUAL, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) # Values preserved and typed - assert out["x"] == features["a"] - assert out["y"] == features["b"] - assert out["c"] == features["c"] + assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"] + assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"] + assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"] assert_contract_is_typed(out) # Input not mutated - assert set(features) == {"a", "b", "c"} + assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"} -def test_feature_contract_overlapping_keys(policy_feature_factory): +def test_features_overlapping_keys(policy_feature_factory): # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c' - processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) + processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"}) features = { - "a": policy_feature_factory(FeatureType.STATE, (1,)), - "b": policy_feature_factory(FeatureType.STATE, (2,)), + PipelineFeatureType.OBSERVATION: { + "a": policy_feature_factory(FeatureType.VISUAL, (1,)), + "b": policy_feature_factory(FeatureType.VISUAL, (2,)), + }, } - out = processor.feature_contract(features) + out = processor.transform_features(features) - assert set(out) == {"b", "c"} - assert out["b"] == features["a"] # 'a' renamed to'b' - assert out["c"] == features["b"] # 'b' renamed to 'c' + assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"} + assert ( + out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"] + ) # 'a' renamed to'b' + assert ( + out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"] + ) # 'b' renamed to 'c' assert_contract_is_typed(out) -def test_feature_contract_chained_processors(policy_feature_factory): +def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level - processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) - processor2 = RenameProcessor( + processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) + processor2 = RenameObservationsProcessorStep( rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} ) - pipeline = RobotProcessor([processor1, processor2]) + pipeline = DataProcessorPipeline([processor1, processor2]) spec = { - "pos": policy_feature_factory(FeatureType.STATE, (7,)), - "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "extra": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "pos": policy_feature_factory(FeatureType.VISUAL, (7,)), + "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "extra": policy_feature_factory(FeatureType.VISUAL, (1,)), + }, } - out = pipeline.feature_contract(initial_features=spec) + out = pipeline.transform_features(initial_features=spec) - assert set(out) == {"observation.state", "observation.image", "extra"} - assert out["observation.state"] == spec["pos"] - assert out["observation.image"] == spec["img"] - assert out["extra"] == spec["extra"] + assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} + assert ( + out[PipelineFeatureType.OBSERVATION]["observation.state"] + == spec[PipelineFeatureType.OBSERVATION]["pos"] + ) + assert ( + out[PipelineFeatureType.OBSERVATION]["observation.image"] + == spec[PipelineFeatureType.OBSERVATION]["img"] + ) + assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert_contract_is_typed(out) + + +def test_rename_stats_basic(): + orig = { + "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + "action": {"mean": np.array([0.0])}, + } + mapping = {"observation.state": "observation.robot_state"} + renamed = rename_stats(orig, mapping) + assert "observation.robot_state" in renamed and "observation.state" not in renamed + # Ensure deep copy: mutate original and verify renamed unaffected + orig["observation.state"]["mean"][0] = 42.0 + assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py new file mode 100644 index 00000000..7cbcb188 --- /dev/null +++ b/tests/processor/test_sac_processor.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for SAC policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default SAC configuration for testing.""" + config = SACConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + ACTION: {"min": torch.full((5,), -1.0), "max": torch.ones(5)}, + } + + +def test_make_sac_processor_basic(): + """Test basic creation of SAC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_sac_processor_normalization_modes(): + """Test that SAC processor correctly handles different normalization modes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization + action = torch.rand(5) * 2 - 1 # Range [-1, 1] + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is normalized and batched + # State should be mean-std normalized + # Action should be min-max normalized to [-1, 1] + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized (but still batched) + assert postprocessed.shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_cuda(): + """Test SAC processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_accelerate_scenario(): + """Test SAC processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = {OBS_STATE: torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_sac_processor_multi_gpu(): + """Test SAC processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = {OBS_STATE: torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_sac_processor_without_stats(): + """Test SAC processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_sac_processor_save_and_load(): + """Test saving and loading SAC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_mixed_precision(): + """Test SAC processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} + action = torch.randn(5, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_sac_processor_batch_data(): + """Test SAC processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Test with batched data + batch_size = 32 + observation = {OBS_STATE: torch.randn(batch_size, 10)} + action = torch.randn(batch_size, 5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5) + + +def test_sac_processor_edge_cases(): + """Test SAC processor with edge cases.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Test with observation that has no state key but still exists + observation = {"observation.dummy": torch.randn(1)} # Some dummy observation to pass validation + action = torch.randn(5) + batch = {TransitionKey.ACTION.value: action, **observation} + processed = preprocessor(batch) + # observation.state wasn't in original, so it won't be in processed + assert OBS_STATE not in processed + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + # Test with zero action (representing "null" action) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5)) + batch = transition_to_batch(transition) + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) + # Action should be present and batched, even if it's zeros + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_sac_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data + observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32 + action = torch.randn(5, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py new file mode 100644 index 00000000..ce162c10 --- /dev/null +++ b/tests/processor/test_smolvla_processor.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for SmolVLA policy processor.""" + +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.policies.smolvla.processor_smolvla import ( + SmolVLANewLineProcessor, + make_smolvla_pre_post_processors, +) +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + EnvTransition, + NormalizerProcessorStep, + ProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + +def create_default_config(): + """Create a default SmolVLA configuration for testing.""" + config = SmolVLAConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + config.vlm_model_name = "HuggingFaceTB/SmolVLM-Instruct" + config.pad_language_to = "max_length" + config.tokenizer_max_length = 100 + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)}, + } + + +def test_make_smolvla_processor_basic(): + """Test basic creation of SmolVLA processor.""" + config = create_default_config() + stats = create_default_stats() + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 6 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor) + # Step 3 would be TokenizerProcessorStep but it's mocked + assert isinstance(preprocessor.steps[4], DeviceProcessorStep) + assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_smolvla_newline_processor_single_task(): + """Test SmolVLANewLineProcessor with single task string.""" + processor = SmolVLANewLineProcessor() + + # Test with task that doesn't have newline + transition = create_transition(complementary_data={"task": "test task"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + # Test with task that already has newline + transition = create_transition(complementary_data={"task": "test task\n"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + +def test_smolvla_newline_processor_list_of_tasks(): + """Test SmolVLANewLineProcessor with list of task strings.""" + processor = SmolVLANewLineProcessor() + + # Test with list of tasks + tasks = ["task1", "task2\n", "task3"] + transition = create_transition(complementary_data={"task": tasks}) + result = processor(transition) + expected = ["task1\n", "task2\n", "task3\n"] + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected + + +def test_smolvla_newline_processor_empty_transition(): + """Test SmolVLANewLineProcessor with empty transition.""" + processor = SmolVLANewLineProcessor() + + # Test with no complementary_data + transition = create_transition() + result = processor(transition) + assert result == transition + + # Test with complementary_data but no task + transition = create_transition(complementary_data={"other": "data"}) + result = processor(transition) + assert result == transition + + # Test with None task + transition = create_transition(complementary_data={"task": None}) + result = processor(transition) + assert result == transition + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_cuda(): + """Test SmolVLA processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action, complementary_data={"task": "test task"}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_accelerate_scenario(): + """Test SmolVLA processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_smolvla_processor_multi_gpu(): + """Test SmolVLA processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_smolvla_processor_without_stats(): + """Test SmolVLA processor creation without dataset statistics.""" + config = create_default_config() + + # Mock the tokenizer processor + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + +def test_smolvla_newline_processor_state_dict(): + """Test SmolVLANewLineProcessor state dict methods.""" + processor = SmolVLANewLineProcessor() + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should do nothing) + processor.load_state_dict({}) + + # Test reset (should do nothing) + processor.reset() + + # Test get_config + config = processor.get_config() + assert config == {} + + +def test_smolvla_newline_processor_transform_features(): + """Test SmolVLANewLineProcessor transform_features method.""" + processor = SmolVLANewLineProcessor() + + # Test transform_features + features = { + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + } + result = processor.transform_features(features) + assert result == features # Should return unchanged + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, _ = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5) + normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition( + observation, action, complementary_data={"task": "test bfloat16 adaptation"} + ) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py new file mode 100644 index 00000000..20979fd6 --- /dev/null +++ b/tests/processor/test_tdmpc_processor.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for TDMPC policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default TDMPC configuration for testing.""" + config = TDMPCConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(12,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(12), "std": torch.ones(12)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_tdmpc_processor_basic(): + """Test basic creation of TDMPC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_tdmpc_processor_normalization(): + """Test that TDMPC processor correctly normalizes and unnormalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is processed and batched + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized (but still batched) + assert postprocessed.shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_cuda(): + """Test TDMPC processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_accelerate_scenario(): + """Test TDMPC processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(12).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_tdmpc_processor_multi_gpu(): + """Test TDMPC processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(12).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_tdmpc_processor_without_stats(): + """Test TDMPC processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_tdmpc_processor_save_and_load(): + """Test saving and loading TDMPC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_mixed_precision(): + """Test TDMPC processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(12, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_tdmpc_processor_batch_data(): + """Test TDMPC processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Test with batched data + batch_size = 64 + observation = { + OBS_STATE: torch.randn(batch_size, 12), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 12) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 6) + + +def test_tdmpc_processor_edge_cases(): + """Test TDMPC processor with edge cases.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Test with only state observation (no image) + observation = {OBS_STATE: torch.randn(12)} + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert OBS_IMAGE not in processed + + # Test with only image observation (no state) + observation = {OBS_IMAGE: torch.randn(3, 224, 224)} + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert OBS_STATE not in processed + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(12, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py new file mode 100644 index 00000000..b3b0c9bf --- /dev/null +++ b/tests/processor/test_tokenizer_processor.py @@ -0,0 +1,1029 @@ +""" +Tests for the TokenizerProcessorStep class. +""" + +import tempfile +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE +from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey +from lerobot.processor.converters import create_transition, identity_transition +from tests.utils import require_package + + +class MockTokenizer: + """Mock tokenizer for testing that mimics transformers tokenizer interface.""" + + def __init__(self, vocab_size: int = 1000): + self.vocab_size = vocab_size + + def __call__( + self, + text: str | list[str], + max_length: int = 512, + truncation: bool = True, + padding: str = "max_length", + padding_side: str = "right", + return_tensors: str = "pt", + **kwargs, + ) -> dict[str, torch.Tensor]: + """Mock tokenization that returns deterministic tokens based on text.""" + if isinstance(text, str): + texts = [text] + else: + texts = text + + batch_size = len(texts) + + # Create mock input_ids and attention_mask + input_ids = torch.zeros(batch_size, max_length, dtype=torch.long) + attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long) + + for i, txt in enumerate(texts): + # Simple mock: use hash of text to generate deterministic tokens + text_hash = hash(txt) % self.vocab_size + seq_len = min(len(txt.split()), max_length) + + # Fill input_ids with simple pattern based on text + for j in range(seq_len): + input_ids[i, j] = (text_hash + j) % self.vocab_size + + # Set attention mask for non-padded positions + attention_mask[i, :seq_len] = 1 + + result = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Return single sequence for single input to match transformers behavior + if len(texts) == 1: + result = {k: v.squeeze(0) for k, v in result.items()} + + return result + + +@pytest.fixture +def mock_tokenizer(): + """Provide a mock tokenizer for testing.""" + return MockTokenizer(vocab_size=100) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_basic_tokenization(mock_auto_tokenizer): + """Test basic string tokenization functionality.""" + # Mock AutoTokenizer.from_pretrained to return our mock tokenizer + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +@require_package("transformers") +def test_basic_tokenization_with_tokenizer_object(): + """Test basic string tokenization functionality using tokenizer object directly.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_list_of_strings_tokenization(mock_auto_tokenizer): + """Test tokenization of a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ["pick up cube", "place on table"]}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["pick up cube", "place on table"] + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert attention_mask.shape == (2, 8) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_keys(mock_auto_tokenizer): + """Test using custom task_key.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "move forward"}, + ) + + result = processor(transition) + + # Check that tokens are stored in observation regardless of task_key + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (5,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_complementary_data(mock_auto_tokenizer): + """Test handling of None complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data=None) + + # create_transition converts None complementary_data to empty dict, so task key is missing + with pytest.raises(KeyError, match="task"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_missing_task_key(mock_auto_tokenizer): + """Test handling when task key is missing.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data={"other_field": "some value"}) + + with pytest.raises(KeyError, match="task"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_task_value(mock_auto_tokenizer): + """Test handling when task value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data={"task": None}) + + with pytest.raises(ValueError, match="Task extracted from Complementary data is None"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_unsupported_task_type(mock_auto_tokenizer): + """Test handling of unsupported task types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Test with integer task - get_task returns None, observation raises ValueError + transition = create_transition(observation={}, complementary_data={"task": 123}) + + with pytest.raises(ValueError, match="Task cannot be None"): + processor(transition) + + # Test with mixed list - get_task returns None, observation raises ValueError + transition = create_transition(observation={}, complementary_data={"task": ["text", 123, "more text"]}) + + with pytest.raises(ValueError, match="Task cannot be None"): + processor(transition) + + +@require_package("transformers") +def test_no_tokenizer_error(): + """Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided.""" + with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"): + TokenizerProcessorStep() + + +@require_package("transformers") +def test_invalid_tokenizer_name_error(): + """Test that error is raised when invalid tokenizer_name is provided.""" + with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer: + # Mock import error + mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found") + + with pytest.raises(Exception, match="Model not found"): + TokenizerProcessorStep(tokenizer_name="invalid-tokenizer") + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_get_config_with_tokenizer_name(mock_auto_tokenizer): + """Test configuration serialization when using tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + expected = { + "tokenizer_name": "test-tokenizer", + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + + +@require_package("transformers") +def test_get_config_with_tokenizer_object(): + """Test configuration serialization when using tokenizer object.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessorStep( + tokenizer=mock_tokenizer, + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + # tokenizer_name should not be in config when tokenizer object is used + expected = { + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + assert "tokenizer_name" not in config + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_state_dict_methods(mock_auto_tokenizer): + """Test state_dict and load_state_dict methods.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Should return empty dict + state = processor.state_dict() + assert state == {} + + # load_state_dict should not raise error + processor.load_state_dict({}) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_reset_method(mock_auto_tokenizer): + """Test reset method.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Should not raise error + processor.reset() + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_integration_with_robot_processor(mock_auto_tokenizer): + """Test integration with RobotProcessor.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = DataProcessorPipeline( + [tokenizer_processor], to_transition=identity_transition, to_output=identity_transition + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and tokenization was applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (6,) + assert attention_mask.shape == (6,) + + # Check that other data is preserved + assert torch.equal( + result[TransitionKey.OBSERVATION]["state"], transition[TransitionKey.OBSERVATION]["state"] + ) + assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION]) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): + """Test saving and loading processor with tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + original_processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=32, task_key="instruction" + ) + + robot_processor = DataProcessorPipeline( + [original_processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor - tokenizer will be recreated from saved config + loaded_processor = DataProcessorPipeline.from_pretrained( + temp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test that loaded processor works + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +@require_package("transformers") +def test_save_and_load_pretrained_with_tokenizer_object(): + """Test saving and loading processor with tokenizer object using overrides.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + original_processor = TokenizerProcessorStep( + tokenizer=mock_tokenizer, max_length=32, task_key="instruction" + ) + + robot_processor = DataProcessorPipeline( + [original_processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor with tokenizer override (since tokenizer object wasn't saved) + loaded_processor = DataProcessorPipeline.from_pretrained( + temp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}, + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test that loaded processor works + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +@require_package("transformers") +def test_registry_functionality(): + """Test that the processor is properly registered.""" + from lerobot.processor import ProcessorStepRegistry + + # Check that the processor is registered + assert "tokenizer_processor" in ProcessorStepRegistry.list() + + # Check that we can retrieve it + retrieved_class = ProcessorStepRegistry.get("tokenizer_processor") + assert retrieved_class is TokenizerProcessorStep + + +@require_package("transformers") +def test_features_basic(): + """Test basic feature contract functionality.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) + + input_features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) + }, + PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + } + + output_features = processor.transform_features(input_features) + + # Check that original features are preserved + assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] + assert "action" in output_features[PipelineFeatureType.ACTION] + + # Check that tokenized features are added + assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION] + + # Check feature properties + tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][ + f"{OBS_LANGUAGE}.attention_mask" + ] + + assert tokens_feature.type == FeatureType.LANGUAGE + assert tokens_feature.shape == (128,) + assert attention_mask_feature.type == FeatureType.LANGUAGE + assert attention_mask_feature.shape == (128,) + + +@require_package("transformers") +def test_features_with_custom_max_length(): + """Test feature contract with custom max_length.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64) + + input_features = {PipelineFeatureType.OBSERVATION: {}} + output_features = processor.transform_features(input_features) + + # Check that features use correct max_length + assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION] + + tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][ + f"{OBS_LANGUAGE}.attention_mask" + ] + + assert tokens_feature.shape == (64,) + assert attention_mask_feature.shape == (64,) + + +@require_package("transformers") +def test_features_existing_features(): + """Test feature contract when tokenized features already exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256) + + input_features = { + PipelineFeatureType.OBSERVATION: { + f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + } + } + + output_features = processor.transform_features(input_features) + + # Should not overwrite existing features + assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"].shape == ( + 100, + ) # Original shape preserved + assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_tokenization_parameters(mock_auto_tokenizer): + """Test that tokenization parameters are correctly passed to tokenizer.""" + + # Create a custom mock that tracks calls + class TrackingMockTokenizer: + def __init__(self): + self.last_call_args = None + self.last_call_kwargs = None + + def __call__(self, *args, **kwargs): + self.last_call_args = args + self.last_call_kwargs = kwargs + # Return minimal valid output + return { + "input_ids": torch.zeros(16, dtype=torch.long), + "attention_mask": torch.ones(16, dtype=torch.long), + } + + tracking_tokenizer = TrackingMockTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", + max_length=16, + padding="longest", + truncation=False, + padding_side="left", + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + + processor(transition) + + # Check that parameters were passed correctly (task is converted to list) + assert tracking_tokenizer.last_call_args == (["test task"],) + assert tracking_tokenizer.last_call_kwargs["max_length"] == 16 + assert tracking_tokenizer.last_call_kwargs["padding"] == "longest" + assert tracking_tokenizer.last_call_kwargs["padding_side"] == "left" + assert tracking_tokenizer.last_call_kwargs["truncation"] is False + assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt" + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_preserves_other_complementary_data(mock_auto_tokenizer): + """Test that other complementary data fields are preserved.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={ + "task": "test task", + "episode_id": 123, + "timestamp": 456.789, + "other_field": {"nested": "data"}, + }, + ) + + result = processor(transition) + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check that all original fields are preserved + assert comp_data["task"] == "test task" + assert comp_data["episode_id"] == 123 + assert comp_data["timestamp"] == 456.789 + assert comp_data["other_field"] == {"nested": "data"} + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_deterministic_tokenization(mock_auto_tokenizer): + """Test that tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "consistent test"}, + ) + + result1 = processor(transition) + result2 = processor(transition) + + tokens1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + tokens2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + # Results should be identical + assert torch.equal(tokens1, tokens2) + assert torch.equal(attention_mask1, attention_mask2) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_empty_string_task(mock_auto_tokenizer): + """Test handling of empty string task.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ""}, + ) + + result = processor(transition) + + # Should still tokenize (mock tokenizer handles empty strings) + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (8,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_very_long_task(mock_auto_tokenizer): + """Test handling of very long task strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=5, truncation=True) + + long_task = " ".join(["word"] * 100) # Very long task + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": long_task}, + ) + + result = processor(transition) + + # Should be truncated to max_length + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (5,) + assert attention_mask.shape == (5,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_padding_side(mock_auto_tokenizer): + """Test using custom padding_side parameter.""" + + # Create a mock tokenizer that tracks padding_side calls + class PaddingSideTrackingTokenizer: + def __init__(self): + self.padding_side_calls = [] + + def __call__( + self, + text, + max_length=512, + truncation=True, + padding="max_length", + padding_side="right", + return_tensors="pt", + **kwargs, + ): + self.padding_side_calls.append(padding_side) + # Return minimal valid output + return { + "input_ids": torch.zeros(max_length, dtype=torch.long), + "attention_mask": torch.ones(max_length, dtype=torch.long), + } + + tracking_tokenizer = PaddingSideTrackingTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + # Test left padding + processor_left = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=10, padding_side="left" + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + processor_left(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "left" + + # Test right padding (default) + processor_right = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=10, padding_side="right" + ) + + processor_right(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "right" + + +@require_package("transformers") +def test_device_detection_cpu(): + """Test that tokenized tensors stay on CPU when other tensors are on CPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CPU tensors + observation = {"observation.state": torch.randn(10)} # CPU tensor + action = torch.randn(5) # CPU tensor + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "test task"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on CPU + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cpu" + assert attention_mask.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_device_detection_cuda(): + """Test that tokenized tensors are moved to CUDA when other tensors are on CUDA.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CUDA tensors + observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor + action = torch.randn(5).cuda() # CUDA tensor + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "test task"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on CUDA + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + assert tokens.device.index == 0 # Should be on same device as input + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +@require_package("transformers") +def test_device_detection_multi_gpu(): + """Test that tokenized tensors match device in multi-GPU setup.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Test with tensors on cuda:1 + device = torch.device("cuda:1") + observation = {"observation.state": torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "multi gpu test"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on cuda:1 + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device == device + assert attention_mask.device == device + + +@require_package("transformers") +def test_device_detection_no_tensors(): + """Test that tokenized tensors stay on CPU when no other tensors exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with no tensors + transition = create_transition( + observation={"metadata": {"key": "value"}}, # No tensors + complementary_data={"task": "no tensor test"}, + ) + + result = processor(transition) + + # Check that tokenized tensors are on CPU (default) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cpu" + assert attention_mask.device.type == "cpu" + + +@require_package("transformers") +def test_device_detection_mixed_devices(): + """Test device detection when tensors are on different devices (uses first found).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + if torch.cuda.is_available(): + # Create transition with mixed devices + observation = { + "observation.cpu": torch.randn(10), # CPU + "observation.cuda": torch.randn(10).cuda(), # CUDA + } + transition = create_transition( + observation=observation, complementary_data={"task": "mixed device test"} + ) + + result = processor(transition) + + # The device detection should use the first tensor found + # (iteration order depends on dict, but result should be consistent) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + # Both should be on the same device + assert tokens.device == attention_mask.device + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_device_detection_from_action(): + """Test that device is detected from action tensor when no observation tensors exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with action on CUDA but no observation tensors + observation = {"metadata": {"key": "value"}} # No tensors in observation + action = torch.randn(5).cuda() + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "action device test"} + ) + + result = processor(transition) + + # Check that tokenized tensors match action's device + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + + +@require_package("transformers") +def test_device_detection_preserves_dtype(): + """Test that device detection doesn't affect dtype of tokenized tensors.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with float tensor (to test dtype isn't affected) + observation = {"observation.state": torch.randn(10, dtype=torch.float16)} + transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) + + result = processor(transition) + + # Check that tokenized tensors have correct dtypes (not affected by input dtype) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.dtype == torch.long # Should remain long + assert attention_mask.dtype == torch.bool # Should be bool (converted in processor) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_integration_with_device_processor(mock_auto_tokenizer): + """Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline.""" + from lerobot.processor import DeviceProcessorStep + + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + # Create pipeline with TokenizerProcessorStep then DeviceProcessorStep + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + device_processor = DeviceProcessorStep(device="cuda:0") + robot_processor = DataProcessorPipeline( + [tokenizer_processor, device_processor], + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Start with CPU tensors + transition = create_transition( + observation={"observation.state": torch.randn(10)}, # CPU + action=torch.randn(5), # CPU + complementary_data={"task": "pipeline test"}, + ) + + result = robot_processor(transition) + + # All tensors should end up on CUDA (moved by DeviceProcessorStep) + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + # Tokenized tensors should also be on CUDA + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_simulated_accelerate_scenario(): + """Test scenario simulating Accelerate with data already on GPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Simulate Accelerate scenario: batch already on GPU + device = torch.device("cuda:0") + observation = { + "observation.state": torch.randn(1, 10).to(device), # Batched, on GPU + "observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU + } + action = torch.randn(1, 5).to(device) # Batched, on GPU + + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": ["accelerate test"]}, # List for batched task + ) + + result = processor(transition) + + # Tokenized tensors should match GPU placement + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device == device + assert attention_mask.device == device + # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) + assert tokens.shape == (10,) # MockTokenizer behavior for single string in list + assert attention_mask.shape == (10,) diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py new file mode 100644 index 00000000..98e05eae --- /dev/null +++ b/tests/processor/test_vqbet_processor.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for VQBeT policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default VQBeT configuration for testing.""" + config = VQBeTConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)}, + } + + +def test_make_vqbet_processor_basic(): + """Test basic creation of VQBeT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_vqbet_processor_with_images(): + """Test VQBeT processor with image and state observations.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Create test data with images and states + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is batched + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_cuda(): + """Test VQBeT processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_accelerate_scenario(): + """Test VQBeT processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_vqbet_processor_multi_gpu(): + """Test VQBeT processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_vqbet_processor_without_stats(): + """Test VQBeT processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_vqbet_processor_save_and_load(): + """Test saving and loading VQBeT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_mixed_precision(): + """Test VQBeT processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_vqbet_processor_large_batch(): + """Test VQBeT processor with large batch sizes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Test with large batch + batch_size = 128 + observation = { + OBS_STATE: torch.randn(batch_size, 8), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 8) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 7) + + +def test_vqbet_processor_sequential_processing(): + """Test VQBeT processor with sequential data processing.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Process multiple samples sequentially + results = [] + for _ in range(5): + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + results.append(processed) + + # Check that all results are consistent + for result in results: + assert result[OBS_STATE].shape == (1, 8) + assert result[OBS_IMAGE].shape == (1, 3, 224, 224) + assert result[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py new file mode 100644 index 00000000..29b7bf70 --- /dev/null +++ b/tests/utils/test_visualization_utils.py @@ -0,0 +1,209 @@ +import importlib +import sys +from types import SimpleNamespace + +import numpy as np +import pytest + +from lerobot.processor import TransitionKey + + +@pytest.fixture +def mock_rerun(monkeypatch): + """ + Provide a mock `rerun` module so tests don't depend on the real library. + Also reload the module-under-test so it binds to this mock `rr`. + """ + calls = [] + + class DummyScalar: + def __init__(self, value): + self.value = float(value) + + class DummyImage: + def __init__(self, arr): + self.arr = arr + + def dummy_log(key, obj, **kwargs): + calls.append((key, obj, kwargs)) + + dummy_rr = SimpleNamespace( + Scalar=DummyScalar, + Image=DummyImage, + log=dummy_log, + init=lambda *a, **k: None, + spawn=lambda *a, **k: None, + ) + + # Inject fake module into sys.modules + monkeypatch.setitem(sys.modules, "rerun", dummy_rr) + + # Now import and reload the module under test, to bind to our rerun mock + import lerobot.utils.visualization_utils as vu + + importlib.reload(vu) + + # Expose both the reloaded module and the call recorder + yield vu, calls + + +def _keys(calls): + """Helper to extract just the keys logged to rr.log""" + return [k for (k, _obj, _kw) in calls] + + +def _obj_for(calls, key): + """Find the first object logged under a given key.""" + for k, obj, _kw in calls: + if k == key: + return obj + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def _kwargs_for(calls, key): + for k, _obj, kw in calls: + if k == key: + return kw + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): + vu, calls = mock_rerun + + # Build EnvTransition dict + obs = { + "observation.state.temperature": np.float32(25.0), + # CHW image should be converted to HWC for rr.Image + "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), + } + act = { + "action.throttle": 0.7, + # 1D array should log individual Scalars with suffix _i + "action.vector": np.array([1.0, 2.0], dtype=np.float32), + } + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: act, + } + + # Extract observation and action data from transition like in the real call sites + obs_data = transition.get(TransitionKey.OBSERVATION, {}) + action_data = transition.get(TransitionKey.ACTION, {}) + vu.log_rerun_data(observation=obs_data, action=action_data) + + # We expect: + # - observation.state.temperature -> Scalar + # - observation.camera -> Image (HWC) with static=True + # - action.throttle -> Scalar + # - action.vector_0, action.vector_1 -> Scalars + expected_keys = { + "observation.state.temperature", + "observation.camera", + "action.throttle", + "action.vector_0", + "action.vector_1", + } + assert set(_keys(calls)) == expected_keys + + # Check scalar types and values + temp_obj = _obj_for(calls, "observation.state.temperature") + assert type(temp_obj).__name__ == "DummyScalar" + assert temp_obj.value == pytest.approx(25.0) + + throttle_obj = _obj_for(calls, "action.throttle") + assert type(throttle_obj).__name__ == "DummyScalar" + assert throttle_obj.value == pytest.approx(0.7) + + v0 = _obj_for(calls, "action.vector_0") + v1 = _obj_for(calls, "action.vector_1") + assert type(v0).__name__ == "DummyScalar" + assert type(v1).__name__ == "DummyScalar" + assert v0.value == pytest.approx(1.0) + assert v1.value == pytest.approx(2.0) + + # Check image handling: CHW -> HWC + img_obj = _obj_for(calls, "observation.camera") + assert type(img_obj).__name__ == "DummyImage" + assert img_obj.arr.shape == (10, 20, 3) # transposed + assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images + + +def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): + vu, calls = mock_rerun + + # First dict without prefixes treated as observation + # Second dict without prefixes treated as action + obs_plain = { + "temp": 1.5, + # Already HWC image => should stay as-is + "img": np.zeros((5, 6, 3), dtype=np.uint8), + "none": None, # should be skipped + } + act_plain = { + "throttle": 0.3, + "vec": np.array([9, 8, 7], dtype=np.float32), + } + + # Extract observation and action data from list like the old function logic did + # First dict was treated as observation, second as action + vu.log_rerun_data(observation=obs_plain, action=act_plain) + + # Expected keys with auto-prefixes + expected = { + "observation.temp", + "observation.img", + "action.throttle", + "action.vec_0", + "action.vec_1", + "action.vec_2", + } + logged = set(_keys(calls)) + assert logged == expected + + # Scalars + t = _obj_for(calls, "observation.temp") + assert type(t).__name__ == "DummyScalar" + assert t.value == pytest.approx(1.5) + + throttle = _obj_for(calls, "action.throttle") + assert type(throttle).__name__ == "DummyScalar" + assert throttle.value == pytest.approx(0.3) + + # Image stays HWC + img = _obj_for(calls, "observation.img") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (5, 6, 3) + assert _kwargs_for(calls, "observation.img").get("static", False) is True + + # Vectors + for i, val in enumerate([9, 8, 7]): + o = _obj_for(calls, f"action.vec_{i}") + assert type(o).__name__ == "DummyScalar" + assert o.value == pytest.approx(val) + + +def test_log_rerun_data_kwargs_only(mock_rerun): + vu, calls = mock_rerun + + vu.log_rerun_data( + observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)}, + action={"action.a": 1.0}, + ) + + keys = set(_keys(calls)) + assert "observation.temp" in keys + assert "observation.gray" in keys + assert "action.a" in keys + + temp = _obj_for(calls, "observation.temp") + assert type(temp).__name__ == "DummyScalar" + assert temp.value == pytest.approx(10.0) + + img = _obj_for(calls, "observation.gray") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (8, 8, 1) # remains HWC + assert _kwargs_for(calls, "observation.gray").get("static", False) is True + + a = _obj_for(calls, "action.a") + assert type(a).__name__ == "DummyScalar" + assert a.value == pytest.approx(1.0)