feat(processors): use pipelines across the codebase (#1452)
* Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. 
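As a rough illustration of the composable design described above, the following sketch chains small observation steps in order. It is not the `RobotProcessor` implementation: `MiniPipeline`, `rename_keys`, `scale_pixels`, and the key names are hypothetical, chosen only to show the idea of replacing hardcoded preprocessing with a list of steps.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

# Hypothetical names for illustration only; this is not the lerobot API.
Observation = Dict[str, Any]
Step = Callable[[Observation], Observation]


@dataclass
class MiniPipeline:
    """Applies a list of steps to an observation, in order."""

    steps: List[Step] = field(default_factory=list)

    def __call__(self, observation: Observation) -> Observation:
        # Work on a shallow copy so the caller's dict is left untouched.
        result = dict(observation)
        for step in self.steps:
            result = step(result)
        return result


def rename_keys(mapping: Dict[str, str]) -> Step:
    """Build a step that renames observation keys according to `mapping`."""

    def step(observation: Observation) -> Observation:
        return {mapping.get(key, key): value for key, value in observation.items()}

    return step


def scale_pixels(observation: Observation) -> Observation:
    """Scale 8-bit pixel values (0..255) to floats in [0, 1]."""
    out = dict(observation)
    if "pixels" in out:
        out["pixels"] = [value / 255.0 for value in out["pixels"]]
    return out


pipeline = MiniPipeline(steps=[scale_pixels, rename_keys({"pixels": "observation.image"})])
processed = pipeline({"pixels": [0, 255], "agent_pos": [0.1, 0.2]})
```

Because each step is a plain callable, steps can be reordered, removed, or tested in isolation, which is the property the refactor aims for.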
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. 
- Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstring): Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. 
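The override mechanism described in the last item can be pictured with a small sketch. Everything here is an assumption made for illustration: the registry, the config layout (`{"steps": [{"name": ..., "kwargs": ...}]}`), and the function `build_steps` are hypothetical and do not reflect the actual saved format or loader. The sketch only shows the general pattern of merging caller-supplied objects into a serialized config and rejecting override keys that match no step.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical registry and config layout, for illustration only.
STEP_REGISTRY: Dict[str, Callable[..., Any]] = {}


def register_step(name: str):
    """Register a step class under `name` so a saved config can refer to it."""

    def decorator(cls):
        STEP_REGISTRY[name] = cls
        return cls

    return decorator


@register_step("scale")
class ScaleStep:
    def __init__(self, factor: float = 1.0):
        self.factor = factor


@register_step("callback")
class CallbackStep:
    def __init__(self, fn: Optional[Callable[[Any], Any]] = None):
        # A callable cannot be written to JSON, so it has to come from overrides.
        self.fn = fn


def build_steps(config: Dict[str, Any], overrides: Optional[Dict[str, dict]] = None) -> List[Any]:
    """Instantiate steps from a saved config, merging per-step overrides."""
    overrides = dict(overrides or {})
    known = {entry["name"] for entry in config["steps"]}
    unknown = set(overrides) - known
    if unknown:
        # Fail early on keys that do not correspond to any step in the config.
        raise KeyError(f"Override keys match no step: {sorted(unknown)}")

    steps = []
    for entry in config["steps"]:
        kwargs = {**entry.get("kwargs", {}), **overrides.get(entry["name"], {})}
        try:
            steps.append(STEP_REGISTRY[entry["name"]](**kwargs))
        except TypeError as exc:
            raise ValueError(f"Could not instantiate step {entry['name']!r}: {exc}") from exc
    return steps


saved = {"steps": [{"name": "scale", "kwargs": {"factor": 2.0}}, {"name": "callback"}]}
steps = build_steps(saved, overrides={"callback": {"fn": len}})
```

Serializable parameters such as `factor` come from the saved config, while the callable is injected at load time.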
* chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. 
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. 
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install command for lerobot that does not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * incorporate PR feedback * type in doc * oops * docs(pipeline): Clarify transition handling and hook behavior - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. 
- Enhanced test assertions to verify the structure of results and the correctness of processing steps. * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * refactor(pipeline): Remove model card generation and streamline processor methods - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability. * refactor(observation): Streamline observation preprocessing and remove unused processor methods - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script. * refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names. 
- Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. * refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationProcessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. 
* feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. 
- Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. 
- Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. 
- Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. 
* feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. 
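The normalization behavior described in the items above (an IDENTITY mode that bypasses normalization, graceful handling of missing statistics, and a record of what was actually normalized) can be sketched in plain Python. This is a simplified stand-in, not the `NormalizerProcessor` code: the `NormMode` enum, the `eps` term, the function signatures, and the `info` dictionary are assumptions made for the example.

```python
from enum import Enum
from typing import Dict, List, Optional, Tuple


class NormMode(Enum):
    # Hypothetical mode names for this sketch.
    MEAN_STD = "mean_std"
    IDENTITY = "identity"


Stats = Dict[str, float]


def normalize(values: List[float], stats: Optional[Stats], mode: NormMode,
              eps: float = 1e-8) -> Tuple[List[float], bool]:
    """Return (values, applied), where `applied` says whether scaling happened."""
    if mode is NormMode.IDENTITY or stats is None:
        # IDENTITY mode or missing statistics: pass the values through unchanged.
        return list(values), False
    if mode is NormMode.MEAN_STD:
        return [(v - stats["mean"]) / (stats["std"] + eps) for v in values], True
    raise ValueError(f"Unsupported normalization mode: {mode}")


def unnormalize(values: List[float], stats: Optional[Stats], mode: NormMode,
                eps: float = 1e-8) -> Tuple[List[float], bool]:
    """Inverse of `normalize` under the same mode and statistics."""
    if mode is NormMode.IDENTITY or stats is None:
        return list(values), False
    if mode is NormMode.MEAN_STD:
        return [v * (stats["std"] + eps) + stats["mean"] for v in values], True
    raise ValueError(f"Unsupported normalization mode: {mode}")


stats = {"observation.state": {"mean": 1.0, "std": 2.0}}
modes = {"observation.state": NormMode.MEAN_STD, "observation.image": NormMode.IDENTITY}
observation = {"observation.state": [1.0, 3.0], "observation.image": [0.5]}

normalized, info = {}, {}
for key, values in observation.items():
    # `info` plays the role of the tracking data kept alongside the transition.
    normalized[key], info[key] = normalize(values, stats.get(key), modes[key])

restored, _ = unnormalize(normalized["observation.state"],
                          stats["observation.state"], NormMode.MEAN_STD)
```

Keeping the per-key `applied` flag makes it possible to tell afterwards which features were scaled and which were passed through.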
* feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feature(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. 
- Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. 
* refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. * fix(train.py) push postprocessor with preprocessor - Add preprocessor policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * refactor(device_processor): Update device handling and improve type hints - Changed device attribute type from torch.device to str for better clarity. - Introduced a private _device attribute to store the actual torch.device instance. - Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments. - Refactored device-related assertions in tests to use a consistent approach for device type verification. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test(tokenizer_processor): Add require_package decorator for transformers - Introduced @require_package("transformers") decorator in multiple test functions to ensure the transformers package is available before running tests. - This change enhances test reliability by preventing failures due to missing dependencies. 
* refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure - Introduced RenameProcessor in the preprocessor to handle renaming features. - Combined input and output features in a single NormalizerProcessor for improved efficiency. - Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor. - Added DeviceProcessor to both preprocessor and postprocessor for better device management. * Integrate pipeline and add phone teleop (#1681) * fix(ci): temporary fix on dataset deps version * refactor(train): Update memory pinning logic for mps compatibility * feat: initial commit phone teleop * ugly delta control * use quaternion 
* Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. 
* chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstring): Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. 
* chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. 
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. 
- Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install command for lerobot that does not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. 
* refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. 
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. 
- Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. 
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * Add debug + calib * cleanup * Add pipeline * fix int * Add record example * nit * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * incorporate PR feedback * type in doc * oops * cleaned up steps and integrated pipeline with feature_contract * refactor steps and robot to pipeline * cleanup pipeline * cleanup code further * make it run * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. 
* refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. 
- Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. 
- Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. 
- Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. 
* fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. 
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * Do some todos and cleanup * change feature_contract to dataset_features * use one method for conversion pipeline output to add_frame dict and use base processors where possible * Add back in and use record_loop * update todo * rename to_dataset_frame * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. 
* feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix reference frame * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. 
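The float dtype behavior described above (cast floating-point tensors, leave other dtypes and non-tensor values untouched) can be sketched roughly as follows. This is a minimal illustration, not the actual DeviceProcessor code: `move_and_cast` and the batch keys are hypothetical names chosen for the example.

```python
import torch


def move_and_cast(value, device="cpu", float_dtype=None):
    """Hypothetical helper: move tensors to `device`, cast only floating-point ones."""
    if isinstance(value, dict):
        # Recurse into nested dictionaries (e.g. observation or complementary data).
        return {k: move_and_cast(v, device, float_dtype) for k, v in value.items()}
    if not isinstance(value, torch.Tensor):
        # Non-tensor values (strings, ints, None) pass through unchanged.
        return value
    value = value.to(device)
    if float_dtype is not None and value.is_floating_point():
        # Integer and boolean tensors keep their dtype.
        value = value.to(dtype=float_dtype)
    return value


# Illustrative batch; the key names are assumptions for this example only.
batch = {
    "observation.state": torch.zeros(2, 6, dtype=torch.float32),
    "index": torch.tensor([3, 4], dtype=torch.int64),
    "task": "pick up the cube",
}
out = move_and_cast(batch, device="cpu", float_dtype=torch.float16)
```

Here the state tensor ends up as `torch.float16`, while `index` stays `torch.int64` and the task string is returned as-is.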
* update data visualization * update teleop example * fix record bugs * Add replay * Not code * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. 
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feature(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. 
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. 
* Add eval script * fix `q_curr` in InverseKinematicsEEToJoints to the IK solution * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. 
- Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. 
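The batching behavior described for ToBatchProcessor (add a leading batch dimension to unbatched state and image observations, leave everything else alone) could look roughly like this. It is a sketch under stated assumptions, not the processor's real implementation: `add_batch_dim`, the key names, and the "1D state / 3D image means unbatched" rule are illustrative choices.

```python
import torch


def add_batch_dim(observation):
    """Hypothetical helper: unsqueeze unbatched state (1D) and image (3D) tensors."""
    batched = {}
    for key, value in observation.items():
        if isinstance(value, torch.Tensor):
            # Assumption: image keys contain "image"; other tensors are state-like.
            unbatched_ndim = 3 if "image" in key else 1
            if value.dim() == unbatched_ndim:
                value = value.unsqueeze(0)
        # Already-batched tensors and non-tensor entries are kept unchanged.
        batched[key] = value
    return batched


# Illustrative observation; key names are assumptions for this example only.
obs = {
    "observation.state": torch.zeros(6),
    "observation.images.front": torch.zeros(3, 64, 64),
    "already_batched.state": torch.zeros(4, 6),
}
batched_obs = add_batch_dim(obs)
```

After the call, the state has shape `(1, 6)` and the image `(1, 3, 64, 64)`, while the tensor that was already two-dimensional keeps its shape.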
* feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. 
* refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. 
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. 
* feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. 
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feature(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. 
* feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * Fix eval and android gripper * add some tests * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. 
* fix(train.py) push postprocessor with preprocessor - Add preprocessor policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * Cleanup pr * fix more git diff pr issues * add path as type in save_pretrained * small nit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename test file * fix: make dataset_features/feature_contract optional * fix tests * Incorporate PR feedback * clean up record.py * add ascii art, fix normal record * remove merge issues * fix merge * remove features * Add feedback PR * fix last 4 tests * remove features check * rename to transform_features * add transform_features * fix lekiwi eval and update eval api example --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> * refactor(TokenizerProcessor): improve dependency handling and observation management - Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements. 
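The conditional-import pattern mentioned for TokenizerProcessor (only use `transformers` when it is installed, otherwise fail with an actionable message) is commonly written along these lines. This is a generic sketch of the pattern using the standard library; `optional_import` is a hypothetical helper and is not taken from the codebase.

```python
import importlib
import importlib.util


def optional_import(module_name, install_hint):
    """Hypothetical helper: import `module_name` or raise a clear ImportError."""
    if importlib.util.find_spec(module_name) is None:
        # The dependency is missing: tell the user how to install it.
        raise ImportError(
            f"'{module_name}' is required for this step but is not installed. {install_hint}"
        )
    return importlib.import_module(module_name)


# Intended use inside a tokenizer step (not executed here):
#   transformers = optional_import("transformers", "Install it with `pip install transformers`.")
#   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
```

Deferring the import until the step is actually constructed keeps the rest of the pipeline usable when the optional dependency is absent.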
* feat(dependencies): Add scipy as a required dependency - Included `scipy>=1.15.2` in the project dependencies to enhance functionality and support for scientific computing tasks. * feat(policies): convert save_policy_to_safetensors with pipeline * refactor(normalization): remove Normalize and Unnormalize classes - Deleted the Normalize and Unnormalize classes from the normalization module to streamline the codebase. - Updated tests to ensure compatibility with the removal of these classes, focusing on the new NormalizerProcessor and UnnormalizerProcessor implementations. - Enhanced the handling of normalization statistics and improved overall code clarity. * refactor(factory): streamline processor loading by removing unused comments - Removed commented-out code related to loading pretrained processors in the make_processor function. - This change enhances code clarity and maintains focus on the current implementation. * feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion - Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios. - Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions. - Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios. * feat(tests): Add comprehensive tests for various policy processors - Introduced new test files for ACT, Classifier, Diffusion, PI0, SAC, SmolVLA, TDMPC, and VQBeT policy processors. - Each test file includes unit tests to validate functionality, including handling of batch sizes, device management, and data type conversions. - Enhanced test coverage to ensure robustness and reliability of processor implementations across different scenarios. 
* refactor(train): Remove unnecessary tensor device handling in training loop * Refactor `gym_manipulator.py` using the universal pipeline (#1650) * Migrate gym_manipulator to use the pipeline Added get_teleop_events function to capture relevant events from teleop devices unrelated to actions * Added the capability to record a dataset * Added the replay functionality with the pipeline * Refactored `actor.py` to use the pipeline * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * RL works at this commit - fixed actor.py and bugs in gym_manipulator * change folder structure to reduce the size of gym_manip * Refactored hilserl config * Remove dataset and mode from HilSerlEnvConfig to a GymManipulatorConfig to reduce the verbosity of configs during training * format docs * removed get_teleop_events from abc * Refactor environment configuration and processing pipeline for GymHIL support. Removed device attribute from HILSerlRobotEnvConfig, added DummyTeleopDevice for simulation, and updated processor creation to accommodate GymHIL environments. * Improved typing for HILRobotEnv config and GymManipulator config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Migrated `gym_manipulator` to use a more modular structure similar to phone teleop * Refactor gripper handling and transition processing in HIL and robot kinematic processors - Updated gripper position handling to use a consistent key format across processors - Improved the EEReferenceAndDelta class to handle reference joint positions. - Added support for discrete gripper actions in the GripperVelocityToJoint processor. - Refactored the gym manipulator to improve modularity and clarity in processing steps. 
* Added delta_action_processor mapping wrapper * Added missing file delta_action_processor and improved imports in `gym_manipulator` * nit * Added missing file joint_observation_processor * Enhance processing architecture with new teleoperation processors - Introduced `AddTeleopActionAsComplimentaryData` and `AddTeleopEventsAsInfo` for integrating teleoperator actions and events into transitions. - Added `Torch2NumpyActionProcessor` and `Numpy2TorchActionProcessor` for seamless conversion between PyTorch tensors and NumPy arrays. - Updated `__init__.py` to include new processors in module exports, improving modularity and clarity in the processing pipeline. - GymHIL is now fully supported with HIL using the pipeline * Refactor configuration structure for gym_hil integration - Renamed sections for better readability, such as changing "Gym Wrappers Configuration" to "Processor Configuration." - Enhanced documentation with clear examples for dataset collection and policy evaluation configurations. * Enhance reset configuration and teleoperation event handling - Added `terminate_on_success` parameter to `ResetConfig` and `InterventionActionProcessor` for controlling episode termination behavior upon success detection. - Updated documentation to clarify the impact of `terminate_on_success` on data collection for reward classifier training. - Refactored teleoperation event handling to use `TeleopEvents` constants for improved readability and maintainability across various modules. 
* fix(keyboard teleop), delta action keys * Added transform features and feature contract * Added transform features for image crop * Enum for TeleopEvents * Update transform_features delta action proc --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Remove HILEnvConfig references * chore(processor): Add default names for preprocessor and postprocessor in constants - Introduced `PREPROCESSOR_DEFAULT_NAME` and `POSTPROCESSOR_DEFAULT_NAME` constants for consistent naming across various processor implementations. - Updated processor creation in multiple policy files to utilize these constants, enhancing code readability and maintainability. - Modified the training script to load and save the preprocessor and postprocessor using the new constants. * feat(processor): multiple improvements to the pipeline porting (#1749) * [Port codebase pipeline] General fixes for RL and scripts (#1748) * Refactor dataset configuration in documentation and codebase - Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency. - Adjusted replay episode handling by renaming `episode` to `replay_episode`. 
- Enhanced documentation - added specific processor to transform from policy actions to delta actions * Added Robot action to tensor processor Added new processor script for dealing with gym specific action processing * removed RobotAction2Tensor processor; improved choosing observations in actor * nit in delta action * added missing reset functions to kinematics * Adapt teleoperate and replay to pipeline similar to record * refactor(processors): move to inheritance (#1750) * fix(teleoperator): improvements phone implementation (#1752) * fix(teleoperator): protect shared state in phone implementation * refactor(teleop): separate classes in phone * fix: solve breaking changes (#1753) * refactor(policies): multiple improvements (#1754) * refactor(processor): simpler logic in device processor (#1755) * refactor(processor): euclidean distance in delta action processor (#1757) * refactor(processor): improvements to joint observations processor migration (#1758) * refactor(processor): improvements to tokenizer migration (#1759) * refactor(processor): improvements to tokenizer migration * fix(tests): tokenizer tests regression from #1750 * fix(processors): fix float comparison and config in hil processors (#1760) * chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761) * refactor(processor): improvements normalize pipeline migration (#1756) * refactor(processor): several improvements normalize processor step * refactor(processor): more improvements normalize processor * refactor(processor): more changes to normalizer * refactor(processor): take a different approach to DRY * refactor(processor): final design * chore(record): revert comment and continue deleted (#1764) * refactor(examples): pipeline phone examples (#1769) * refactor(examples): phone teleop + teleop script * refactor(examples): phone replay + replay * chore(examples): rename phone example files & folders * feat(processor): fix improvements to the pipeline porting (#1796) * 
refactor(processor): enhance tensor device handling in normalization process (#1795) * refactor(tests): remove unsupported device detection test for complementary data (#1797) * chore(tests): update ToBatchProcessor test (#1798) * refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor * test(tests): add tests for action and task processing in batch processor * add names for android and ios phone (#1799) * use _tensor_stats in normalize processor (#1800) * fix(normalize_processor): correct device reference for tensor epsilon handling (#1801) * add point 5 add missing feature contracts (#1806) * Fix PR comments 1452 (#1807) * use key to determine image * Address rest of PR comments * use PolicyFeatures in transform_features --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * refactor(constants, processor): standardize action and observation keys across multiple files (#1808) - Added new constants for truncated and done states in constants.py. - Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability. 
* refactor(processor): improve processor pipeline typing with generic type (#1810) * refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided. * refactor(converters): implement unified tensor conversion function (#1830) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior. * Revert "refactor(converters): implement unified tensor conversion function (#…" (#1840) This reverts commit a837685bf870919fc07ada287a71711cebabb1ea. * refactor(converters): implement unified tensor conversion function (#1841) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. 
- Updated tests to cover new functionality and ensure correct tensor conversion behavior. Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> * refactor(converters): gather converters and refactor the logic (#1833) * refactor(converters): move batch transition functions to converters module - Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns. - Updated references in `RobotProcessor` to use the new location of these functions. - Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields. - Removed redundant tests from `pipeline.py` to streamline the test suite. * refactor(processor): reorganize EnvTransition and TransitionKey definitions - Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability. - Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase. * refactor(converters): rename and update dataset frame conversion functions - Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming. - Updated references in `record.py`, `pipeline.py`, and tests to use the new function name. - Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability. - Adjusted related tests to ensure correct functionality with the new naming conventions. 
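Note on the `to_tensor` refactor (#1841) described above: the following is a minimal sketch of the `singledispatch` pattern it names, not the actual implementation. To stay dependency-free it converts inputs to plain lists of floats instead of PyTorch tensors; the name `to_float_list` and the registered types are assumptions made for illustration only.

```python
from functools import singledispatch
from numbers import Real
from typing import Any


@singledispatch
def to_float_list(value: Any) -> Any:
    """Fallback: reject any type without a registered converter."""
    raise TypeError(f"Unsupported type: {type(value).__name__}")


@to_float_list.register
def _from_scalar(value: Real) -> list:
    # Scalars (int, float, bool) become a one-element list.
    return [float(value)]


@to_float_list.register(list)
@to_float_list.register(tuple)
def _from_sequence(value) -> list:
    # Flat sequences are converted element by element.
    return [float(v) for v in value]


@to_float_list.register
def _from_dict(value: dict) -> dict:
    # Dictionaries are converted recursively, keeping their keys.
    return {key: to_float_list(item) for key, item in value.items()}
```

One dispatching function replaces a chain of `isinstance` checks, and support for a new input type is added by registering another converter rather than editing existing branches.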
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(processor): solve conflict artefacts * refactor(converters): remove unused identity function and update type hints for merge_transitions * refactor(processor): remove unused identity import and clean up gym_manipulator.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co> * refactor(processors): add transform_features method to various processors (#1843) * refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844) * refactor(processors): unify import statements by consolidating pipeline imports into the main processor module (#1845) * refactor(processors): add extended api for specialized pipelines (#1848) * refactor(processors): enhance transform_features method across multiple processors (#1849) * refactor(processors): enhance transform_features method across multiple processors - Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features. - Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others. - Improved readability and maintainability by following consistent patterns in feature transformation. * refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor - Updated action and observation keys to use constants for improved readability and maintainability. - Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys. - Enhanced error handling by raising exceptions for missing required components in action and observation processing. - Removed obsolete code and improved overall structure for better clarity. 
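Note on the `transform_features` changes (#1849) described above: a minimal sketch of the copy-then-modify pattern, assuming a features mapping of plain nested dictionaries. The function signature, the group names, and the keys `delta_x` and `ee.x` are hypothetical and are not taken from the codebase; copying the nested dictionaries as well as the outer one is an assumption of this sketch.

```python
def transform_features(features: dict) -> dict:
    """Return an updated feature mapping without mutating the input."""
    # Copy the outer dict and each nested dict so that edits below
    # never leak back into the caller's mapping.
    updated = {group: dict(inner) for group, inner in features.items()}

    action = updated.setdefault("action", {})
    # Remove an obsolete key if present (hypothetical name).
    action.pop("delta_x", None)
    # Add the key this step produces (hypothetical name and shape).
    action["ee.x"] = (1,)
    return updated
```

Usage: calling `transform_features(original)` returns a new mapping while `original` keeps its `delta_x` entry, so several steps can be chained without one step corrupting the contract another step declared.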
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(processors): remove unused import in joint_observations_processor * refactor(processors): simplify transform_features method in delta_action_processor * refactor(processors): streamline transform_features method in ImageCropResizeProcessor * refactor(processors): improve error handling and streamline transform_features method in phone_processor - Raised a ValueError for missing position and rotation in action to enhance error handling. * refactor(processors): enhance error handling in JointVelocityProcessor - Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method. * refactor(processors): simplify transform_features method in robot kinematic processors * refactor(processors): standardize action keys in phone_processor * fix(processor): RKP feature obs -> act --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co> * chore(processor): rename RobotProcessor -> DataProcessorPipeline (#1850) * chore(processor): rename specialized processor -> XYZProcessorStep (#1852) * chore(processor): rename converters function names (#1853) * chore(processor): rename to_transition_teleop_action -> action_to_transition * chore(processor): rename to_transition_robot_observation -> observation_to_transition * chore(processor): rename to_output_robot_action -> transition_to_robot_action * chore(processor): add Step suffix to all processors (#1854) * refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency * refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules * refactor(processor): rename Torch2NumpyActionProcessor to 
Torch2NumpyActionProcessorStep for consistency * refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency * refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency * refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency * refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency * refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency * refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency * refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency * refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency * refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency * refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency * refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency * refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency * refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency * refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency * refactor(processor): update config file name in test for RenameProcessorStep consistency * refactor(processor): rename internal tokenizer variable for clarity (#1855) - Changed the internal tokenizer variable name from `_tokenizer` to `input_tokenizer` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. 
* chore(processor): rename merge_features -> combine_feature_dicts (#1856) * refactor(processor): rename internal device variable for clarity (#1857) - Changed the internal device variable from `_device` to `tensor_device` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. * chore(processor): rename teleop_phone variable names (#1858) * chore(processor): add type alias RobotProcessorPipeline and PolicyProcessorPipeline (#1859) * feat(processor): introduce PolicyProcessorPipeline and RobotProcessorPipeline as type aliases for DataProcessorPipeline - Added PolicyProcessorPipeline and RobotProcessorPipeline type aliases to enhance clarity and maintainability in the processor module. - Updated the __all__ list to include the new pipelines for better module export consistency. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline across multiple modules - Updated all instances of DataProcessorPipeline to PolicyProcessorPipeline in various processor files for consistency and clarity. - Adjusted function signatures to reflect the new pipeline type, enhancing maintainability and readability. * refactor(processor): update hotswap_stats function to use PolicyProcessorPipeline - Changed the parameter name from robot_processor to policy_processor for clarity. - Ensured consistency with recent updates to the processor module by reflecting the new pipeline type in the function signature. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline in migrate_policy_normalization.py - Updated the preprocessor and postprocessor to use PolicyProcessorPipeline for consistency with recent changes in the processor module. - Enhanced clarity and maintainability by aligning with the new pipeline structure. 
* refactor(processor): update hotswap_stats to use PolicyProcessorPipeline - Changed the parameter type in hotswap_stats from DataProcessorPipeline to PolicyProcessorPipeline for consistency with recent updates. - Enhanced clarity by updating the function documentation to reflect the new pipeline type. * refactor(processor): replace DataProcessorPipeline with RobotProcessorPipeline across multiple files - Updated instances of DataProcessorPipeline to RobotProcessorPipeline in evaluate.py, record.py, replay.py, teleoperate.py, and other relevant files for consistency and clarity. - Adjusted function signatures and variable types to reflect the new pipeline structure, enhancing maintainability and readability. * refactor(processor): enforce config_filename requirement for HF Hub loading (#1860) - Updated the DataProcessorPipeline to require a specific config_filename when loading from Hugging Face Hub, enhancing clarity and preventing errors. - Simplified local path checks and improved error handling for invalid paths. - Adjusted tests to reflect the new requirement and ensure proper error handling for various loading scenarios. * feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861) - Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording. - Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features. * refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862) - Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. 
- Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability. * refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863) - Removed the scipy dependency from the project to streamline requirements. - Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies. - Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities. * feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866) - Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events. - Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method. - Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators. - Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators. * fix(deps): use in-house rotation utils over scipy throughout the codebase * refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868) - Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context. - Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase. * refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869) - Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring. 
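Note on the `HasTeleopEvents` protocol (#1866) described above: a minimal sketch of a runtime-checkable protocol guarding a call to `get_teleop_events()`. Only the protocol and method names come from this changelog; the return type, the example classes, and the `extract_events` helper are assumptions made for illustration.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasTeleopEvents(Protocol):
    """Structural type for teleoperators that expose control events."""

    def get_teleop_events(self) -> dict[str, Any]: ...


class EventfulTeleop:
    # Hypothetical teleoperator that satisfies the protocol.
    def get_teleop_events(self) -> dict[str, Any]:
        return {"is_intervention": False}


class PlainTeleop:
    # Hypothetical teleoperator without event support.
    def get_action(self) -> dict[str, float]:
        return {}


def extract_events(teleop: object) -> dict[str, Any]:
    # isinstance() on a runtime-checkable protocol only verifies that the
    # method exists; it does not validate the signature or return type.
    if not isinstance(teleop, HasTeleopEvents):
        raise TypeError(
            f"{type(teleop).__name__} must implement get_teleop_events()"
        )
    return teleop.get_teleop_events()
```

Because the check is structural, a custom teleoperator does not need to inherit from any base class; defining the method is enough to pass it.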
* refactor(utils): simplify log_rerun_data function (#1864) * refactor(logging): enhance log_rerun_data to handle observation and action separately - Updated the `log_rerun_data` function to accept and log observation and action data more clearly, improving readability and maintainability. - Refactored the `record_loop` and `teleop_loop` functions to extract and pass observation and action data to `log_rerun_data`, ensuring consistent logging format. * refactor(tests): update test_log_rerun_data to align with log_rerun_data changes - Modified test cases in `test_visualization_utils.py` to extract and pass observation and action data separately to `log_rerun_data`, improving clarity and consistency with recent function updates. - Ensured that the tests reflect the new structure of `log_rerun_data` for better maintainability. * refactor(processors): simplify calls to log_rerun + replace lambda functions with identity_transition --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co> * fix(processor): recover type inference for use of processors (#1873) * refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880) * refactor(processors): reorder processor steps for consistency across implementations - Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep. - Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability. * refactor(normalization): remove dtype specification in tensor conversion for adaptation logic - Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types. - Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types. 
* chore(docs): update docstrings pipeline files (#1872) * docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co> * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions (#1900) * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions - Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline. - Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow.
* refactor(eval): remove redundant observation device conversion in rollout function - Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability. - This change simplifies the observation handling process, aligning with the preference for clearer solutions. * debug * refactor(utils): enhance task handling in add_envs_task function - Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings. - Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability. - Streamlined the observation dictionary handling by ensuring consistent data types for task attributes. * refactor(converters): rename _from_tensor to from_tensor_to_numpy for clarity (#1902) - Updated the function name from _from_tensor to from_tensor_to_numpy to better reflect its purpose of converting PyTorch tensors to numpy arrays or scalars. - Adjusted all references to the renamed function throughout the codebase to maintain consistency. - Enhanced the _NormalizationMixin class to reconstruct the stats dictionary from tensor stats using the new function, ensuring compatibility after loading state dicts. - Added tests to verify the correct reconstruction of stats and functionality of methods dependent on self.stats after loading. 
* refactor(pipeline): feature contract now categorizes between OBS or Action (#1867) * refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor * refactor(eval): specify type parameters for preprocessor and postprocessor in eval_policy function (#1904) * chore(processor): remove action prefixes (#1905) * test(processor): all processors use now the same create_transition (#1906) * test(processor): all processors use now the same create_transition * test(processor): use identity instead of lambda for transition in pipelines * fix(processor): specialized processors respect contract by raising if none (#1909) * fix(processor): specialized processor now raise * test(processor): fix tests for now raise specialized processors * test(processor): use identity in newly introduced pipeline * refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908) * refactor(processor): split action from policy, robots and environment - Updated function names to robot_action_to_transition and 
robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions. - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention. - Enhanced type annotations for action parameters to improve code readability and maintainability. * refactor(converters): rename robot_transition_to_action to transition_to_robot_action - Updated function names across multiple files to improve clarity and consistency in processing robot actions. - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention. - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence. * refactor(converters): update references to transition_to_robot_action - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions. - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability. * refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing. - This modification enhances the clarity of the class's role in the processing pipeline. 
* fix(processor): main action processor can take also EnvAction --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co> * refactor(processor): phone processor is now a RobotActionProcessorStep * fix(processor): use subprocessors in AddBatchDimensionProcessorStep only if we have the ingredients * fix(robots): remove action prefix hard-coded in teleop keyboard and gamepad * feat(processor): enhance type safety with generic DataProcessorPipeline for policy and robot pipelines (#1915) * refactor(processor): enhance type annotations for processors in record, replay, teleoperate, and control utils - Updated type annotations for preprocessor and postprocessor parameters in record_loop and predict_action functions to specify the expected dictionary types. - Adjusted robot_action_processor type in ReplayConfig and TeleoperateConfig to improve clarity and maintainability. - Ensured consistency in type definitions across multiple files, enhancing overall code readability. * refactor(processor): enhance type annotations for RobotProcessorPipeline in various files - Updated type annotations for RobotProcessorPipeline instances in evaluate.py, record.py, replay.py, teleoperate.py, and other related files to specify input and output types more clearly. - Introduced new type conversions for PolicyAction and EnvTransition to improve type safety and maintainability across the processing pipelines. - Ensured consistency in type definitions, enhancing overall code readability and reducing potential runtime errors. * refactor(processor): update transition handling in processors to use transition_to_batch - Replaced direct transition handling with transition_to_batch in various processor tests and implementations to ensure consistent batching of input data. - Updated assertions in tests to reflect changes in data structure, enhancing clarity and maintainability.
- Improved overall code readability by standardizing the way transitions are processed across different processor types. * refactor(tests): standardize transition key usage in processor tests - Updated assertions in processor test files to utilize the TransitionKey for action references, enhancing consistency across tests. - Replaced direct string references with TransitionKey constants for improved readability and maintainability. - Ensured that all relevant tests reflect these changes, contributing to a more uniform approach in handling transitions. * refactor(processor): unify action imports and enhance type clarity across multiple files - Updated imports in various files to include RobotAction and PolicyAction directly from the processor module, improving clarity and consistency. - Removed redundant imports from core, streamlining the codebase and enhancing maintainability. - Adjusted type annotations and references in the RobotProcessorPipeline and related components to align with the new import structure, ensuring better type safety and readability. * refactor(processor): migrate policy normalization to use factory functions - Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase. - Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture. - Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config. - Enhanced type safety and readability by refining feature type and normalization mode handling. 
* debug(scripts): simplify record with processors (#1918) Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> * refactor(processor): update migration script for policy normalization and hub integration - Modified the migration script to include a branch argument for pushing to the hub, enhancing flexibility in version control. - Improved error handling by ensuring the policy type is extracted from the configuration, promoting robustness. - Streamlined the process of saving and pushing model components to the hub, allowing for a single commit with optional PR creation. - Updated the commit message and description for better clarity on the migration changes and benefits, ensuring users are informed of the new architecture and usage. * fixes for processors used in phone teleop * fixes for rotation matrix * add empty obs and act in create_initial_features * use observation instead of obs * docs(processor): update docstrings pipeline (#1920) * chore(docs): Processor doc (#1685) * chore(docs): initialize doc * Added script for the second part of the processor doc * precommit style nit * improved part 2 of processor guide * Add comprehensive documentation for processors in robotics - Introduced a detailed guide on processors, covering their role in transforming raw robot data into model-ready inputs and vice versa. - Explained core concepts such as EnvTransition, ProcessorStep, and RobotProcessor, along with their functionalities. - Included examples of common processor steps like normalization, device management, batch processing, and text tokenization. - Provided insights on building complete pipelines, integrating processors into training loops, and saving/loading configurations. - Emphasized best practices and advanced features for effective usage of processors in robotics applications. 
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(docs): Enhance introduction to processors with additional converter functions - Updated the introduction to processors documentation to include default batch-to-transition and transition-to-batch converters. - Added detailed descriptions and examples for new specialized converter functions: `to_transition_teleop_action`, `to_transition_robot_observation`, `to_output_robot_action`, and `to_dataset_frame`. - Improved clarity on how these converters facilitate integration with existing robotics applications. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Improved doc implement_your_own_pipeline - Use normalization processor as default example - Add section on transform features - Add section on overrides. * Add phone docs and use pipeline for robots/teleop docs * Fix typo in documentation for adapters in robots/teleop section * Enhance documentation for processors with detailed explanations and examples - Updated the introduction to processors, clarifying the role of `EnvTransition` and `ProcessorStep`. - Introduced `DataProcessorPipeline` as a generic orchestrator for chaining processor steps. - Added comprehensive descriptions of new converter functions and their applications. - Improved clarity on type safety and the differences between `RobotProcessorPipeline` and `PolicyProcessorPipeline`. - Included examples for various processing scenarios, emphasizing best practices for data handling in robotics. * Enhance documentation for processor migration and debugging - Added detailed sections on the migration of models to the new `PolicyProcessorPipeline` system, including breaking changes and migration scripts. - Introduced a comprehensive guide for debugging processor pipelines, covering common issues, step-by-step inspection, and runtime monitoring techniques. 
  - Updated examples to reflect new usage patterns and best practices for processor implementation and error handling.
  - Clarified the role of various processor steps and their configurations in the context of robotics applications.
---------
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
* docs: Add new section for debugging processor pipelines
  - Introduced a new documentation entry for debugging processor pipelines, enhancing the existing guide on processors.
  - This addition aims to provide users with insights and best practices for troubleshooting and optimizing their processor workflows.
* fix(processor): phone examples (#1921)
* fix(processor): phone examples
* chore(processor): simplify gripper in phone example kinematic chain
---------
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
* refactor(processors): several additions (#1926)
* chore(processor): remove merge_transitions functions (#1925)
* refactor(processors): move processors out of configs (#1927)
* chore(processor): streamline combine_features_dict (#1928)
* chore(policies): use new constants (#1929)
* fix(deps): right version transformers (#1930)
* fix(tests): add none + disable async tests for now (#1931)
* refactor(processor): transform_features loop + EAFP (#1932)
* fix(processors): make sure nested dict are also shallow copied (#1939)
* refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937)
  - Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality.
  - Refactored save_pretrained method to handle saving
* refactor(docs): streamline monitoring hooks and enhance performance reporting
  - Removed the log_shapes and measure_performance hooks, simplifying the monitoring process to focus on NaN checks.
  - Updated performance reporting to include maximum processing times alongside average times for better insights.
  - Clarified documentation regarding the processing pipeline and feature transformations.
* fix teleop, record and eval (#1940)
* fix cmd record, eval
* chore(processor): update input output of main 3 processors for better semantics (#1942)
* chore(processor): update input output of main 3 processors for better semantics
* refactor(processor): replace Any with RobotObservation for improved type safety in processors
* fix(processors): no PolicyObservation
* chore(processor): update with RobotObservation
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
---------
Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* test(processor): fix batch expectation
* feat(example): Add SO100 EE pipeline control (teleop+record) (#1943)
* feat(examples): add ee so100 processors teleop & record
* refactor(processor): improve FK processor for better use compatibility
* docs(processor): enhance tutorial on implementing custom processors
  - Updated the tutorial to use `NormalizerProcessorStep` as the primary example, clarifying its role in normalizing observations and actions.
  - Improved explanations of the need for custom processors, emphasizing data compatibility and processing requirements.
  - Added code snippets demonstrating the normalization process and the configuration of processor pipelines.
  - Enhanced the introduction to processors, detailing their function as translators between raw robot data and model inputs.
  - Included examples of real-world processor configurations for both training and inference scenarios.
* docs(debug): enhance debugging guide for processor pipelines
  - Streamlined the introduction to clarify the challenges of debugging complex processor pipelines.
  - Expanded the section on hooks, detailing their purpose and implementation for runtime monitoring.
  - Introduced step-by-step debugging techniques, emphasizing the use of the `step_through()` method for inspecting intermediate states.
  - Added examples of feature validation to ensure data structure contracts are met.
  - Consolidated best practices for debugging, highlighting the synergy between hooks, step-through debugging, and feature validation.
* chore(processors): tokenizers raises and remove tensor conversion (#1949)
* chore(processor): remove unused transition_features dict
* feat(ee): add so100_to_so100_EE replay and evaluate examples
* chore(examples): homogenize style across example files (#1955)
* chore(examples): homogenize style across example files
* chore(examples): homogenize style across example files eval + replay
* chore(examples): homogenize headers
* test(async): fix feature manipulation (#1957)
* test(async): fix feature manipulation
* chore(processor): remove unused functions
* fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)
* feat(processor): enhance normalization handling and state management
  - Added support for additional normalization modes including IDENTITY.
  - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys.
  - Implemented preservation of explicitly provided normalization statistics during state loading.
  - Updated training script to conditionally provide dataset statistics based on resume state.
  - Expanded tests to verify the correct behavior of stats override preservation and loading.
* fix(train): remove redundant comment regarding state loading
  - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity.
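The `clean_state_dict` entry above describes removing specific substrings from state dict keys. A minimal sketch of that idea follows, assuming a plain dict of keys to values; the helper name `strip_from_keys` and the `module.` prefix are illustrative assumptions, not the function or substrings this commit actually uses.

```python
from typing import Any, Iterable


def strip_from_keys(state: dict[str, Any], substrings: Iterable[str]) -> dict[str, Any]:
    """Return a copy of `state` whose keys have every given substring removed (hypothetical helper)."""
    substrings = list(substrings)
    cleaned: dict[str, Any] = {}
    for key, value in state.items():
        new_key = key
        for sub in substrings:
            new_key = new_key.replace(sub, "")
        # Refuse to silently merge two entries that collapse to the same key.
        if new_key in cleaned:
            raise ValueError(f"key collision after cleaning: {new_key!r}")
        cleaned[new_key] = value
    return cleaned
```

For example, `strip_from_keys({"module.layer.weight": 1}, ["module."])` yields a dict keyed by `layer.weight`. Raising on collisions is a design choice of this sketch, made so that a lossy rename fails loudly rather than dropping an entry.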
* test(processor): update tests to handle missing or invalid task keys
  - Modified tests to assert that the processor raises appropriate exceptions when the task key is missing or has an invalid value in the complementary data.
  - Ensured that the tests cover cases for None, integer, and mixed list task values, improving robustness against invalid inputs.
* fix(processor): enforce signatures
* chore(processor): update comments in record.py
* test(processor): fix isinstance and cuda test
* modify phone docs
* fix(processor): reorder output steps to ensure correct processing sequence (#1961)
  - Moved DeviceProcessorStep to the end of the output steps in multiple processor files to maintain the intended processing order.
  - Updated corresponding tests to reflect the change in step order.
* fix(processors): assumptions for robot_action_processor & teleop_action_processor (#1964)
* fix(processors): new assumptions pipeline
* fix(processors): ee jj phone teleop replay record working
* chore(processors): update comments and default vars
* chore(processor): remove unnecessary copy
* chore(processor): added todo assumption gripper
* fix(processors): eval using detected device
* finish phone docs
* fix correct image link
* feat(processor): implement migration detection and error handling for processor configurations (#1968)
* feat(processor): implement migration detection and error handling for processor configurations
  - Added ProcessorMigrationError to handle migration requirements for old model formats.
  - Enhanced DataProcessorPipeline.from_pretrained to include robust migration detection logic.
  - Implemented methods for resolving configuration sources, validating loaded configs, and checking for valid processor configurations.
  - Introduced comprehensive tests for migration detection and configuration validation to ensure correct behavior.
* refactor(processor): simplify loading logic and enhance migration detection
  - Refactored DataProcessorPipeline to implement a simplified three-way loading strategy for configuration files.
  - Introduced explicit config_filename parameter to avoid ambiguity during loading.
  - Updated ProcessorMigrationError to provide clearer error messages for migration requirements.
  - Enhanced tests to cover new loading logic and ensure proper migration detection.
  - Removed deprecated methods related to config source resolution.
* fix(processor) RL (#1953)
* fix(gym_manipulator) general fixes to make it compatible
* fix for dataset v3.0
* fix for gym_manipulator
* add map policy action to robot action wrappers in a separate script
* added unittest for policy to robot bridge
* fixes for gripper penalty
* fix style
* fix gamepad controller
* fixes for sim teleop
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* modify numpy2torch to a regular processor as a quick fix
* missing imports?!
* - Removed the use of `AddRobotObservationAsComplimentaryData` from `gym_manipulator` and thus the codebase
  - Added get_raw_joint_positions functions to RobotEnv
  - Pass raw_joint_positions as input to the action_pipeline in `gym_manipulator`
  - Add `InverseKinematicsRLStep` to be tailored towards the need of RL which requires the use of the IK solution as the main reference point of the control loop
  - Added the option `use_ik_solution` in `EEReferenceDelta` step to rely on the ik solution rather than the joint values
* - Updated links to all the config files to place them in the new repo with configs compatible with the pipeline
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
* fix(tests): update test cases for loading pipelines with specific config filenames
  - Modified test cases to include explicit configuration filenames when loading pipelines in `test_policy_robot_bridge.py`.
  - Ensured that the tests reflect the correct loading behavior for both robot-to-policy and policy-to-robot transitions.
* fix(examples): train mps processor (#1970)
* fix(examples): train mps processor
* fix(processor): add MPS compatibility for float64 tensors
  - Implemented a workaround to convert float64 tensors to float32 when using the MPS device, as MPS does not support float64.
  - Added unit tests to verify the automatic conversion of float64 tensors to float32 and ensure compatibility with various tensor types on the MPS device.
---------
Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
---------
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Pepijn <pepijn@huggingface.co>
412
tests/processor/test_act_processor.py
Normal file
@@ -0,0 +1,412 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ACT policy processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.processor_act import make_act_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DataProcessorPipeline,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch


def create_default_config():
    """Create a default ACT configuration for testing."""
    config = ACTConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }
    config.device = "cpu"
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
        ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
    }


def test_make_act_processor_basic():
    """Test basic creation of ACT processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 4
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)


def test_act_processor_normalization():
    """Test that ACT processor correctly normalizes and unnormalizes data."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Create test data
    observation = {OBS_STATE: torch.randn(7)}
    action = torch.randn(4)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is normalized and batched
    assert processed[OBS_STATE].shape == (1, 7)
    assert processed[TransitionKey.ACTION.value].shape == (1, 4)

    # Process action through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is unnormalized
    assert postprocessed.shape == (1, 4)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_cuda():
    """Test ACT processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Create CPU data
    observation = {OBS_STATE: torch.randn(7)}
    action = torch.randn(4)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

    # Process through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is back on CPU
    assert postprocessed.device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_accelerate_scenario():
    """Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Simulate Accelerate: data already on GPU
    device = torch.device("cuda:0")
    observation = {OBS_STATE: torch.randn(1, 7).to(device)}  # Already batched and on GPU
    action = torch.randn(1, 4).to(device)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on same GPU (not moved unnecessarily)
    assert processed[OBS_STATE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_act_processor_multi_gpu():
    """Test ACT processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Simulate data on different GPU (like in multi-GPU training)
    device = torch.device("cuda:1")
    observation = {OBS_STATE: torch.randn(1, 7).to(device)}
    action = torch.randn(1, 4).to(device)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on cuda:1 (not moved to cuda:0)
    assert processed[OBS_STATE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_act_processor_without_stats():
    """Test ACT processor creation without dataset statistics."""
    config = create_default_config()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        dataset_stats=None,
    )

    # Should still create processors, but normalization won't have stats
    assert preprocessor is not None
    assert postprocessor is not None

    # Process should still work (but won't normalize without stats)
    observation = {OBS_STATE: torch.randn(7)}
    action = torch.randn(4)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed is not None


def test_act_processor_save_and_load():
    """Test saving and loading ACT processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save preprocessor
        preprocessor.save_pretrained(tmpdir)

        # Load preprocessor
        loaded_preprocessor = DataProcessorPipeline.from_pretrained(
            tmpdir, config_filename="policy_preprocessor.json"
        )

        # Test that loaded processor works
        observation = {OBS_STATE: torch.randn(7)}
        action = torch.randn(4)
        transition = create_transition(observation, action)
        batch = transition_to_batch(transition)

        processed = loaded_preprocessor(batch)
        assert processed[OBS_STATE].shape == (1, 7)
        assert processed[TransitionKey.ACTION.value].shape == (1, 4)


def test_act_processor_device_placement_preservation():
    """Test that ACT processor preserves device placement correctly."""
    config = create_default_config()
    stats = create_default_stats()

    # Test with CPU config
    config.device = "cpu"
    preprocessor, _ = make_act_pre_post_processors(
        config,
        stats,
    )

    # Process CPU data
    observation = {OBS_STATE: torch.randn(7)}
    action = torch.randn(4)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed[OBS_STATE].device.type == "cpu"
    assert processed[TransitionKey.ACTION.value].device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_mixed_precision():
    """Test ACT processor with mixed precision (float16)."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Modify the device processor to use float16
    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Replace DeviceProcessorStep with one that uses float16
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Update normalizer to use the same device as the device processor
            norm_step = step  # Now type checker knows this is NormalizerProcessorStep
            modified_steps.append(
                NormalizerProcessorStep(
                    features=norm_step.features,
                    norm_map=norm_step.norm_map,
                    stats=norm_step.stats,
                    device=config.device,
                    dtype=torch.float16,  # Match the float16 dtype
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Create test data
    observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
    action = torch.randn(4, dtype=torch.float32)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is converted to float16
    assert processed[OBS_STATE].dtype == torch.float16
    assert processed[TransitionKey.ACTION.value].dtype == torch.float16


def test_act_processor_batch_consistency():
    """Test that ACT processor handles different batch sizes correctly."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_act_pre_post_processors(
        config,
        stats,
    )

    # Test single sample (unbatched)
    observation = {OBS_STATE: torch.randn(7)}
    action = torch.randn(4)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed["observation.state"].shape[0] == 1  # Batched

    # Test already batched data
    observation_batched = {OBS_STATE: torch.randn(8, 7)}  # Batch of 8
    action_batched = torch.randn(8, 4)
    transition_batched = create_transition(observation_batched, action_batched)
    batch_batched = transition_to_batch(transition_batched)

    processed_batched = preprocessor(batch_batched)
    assert processed_batched[OBS_STATE].shape[0] == 8
    assert processed_batched[TransitionKey.ACTION.value].shape[0] == 8


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_act_pre_post_processors(
        config,
        stats,
    )

    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            norm_step = step  # Now type checker knows this is NormalizerProcessorStep
            modified_steps.append(
                NormalizerProcessorStep(
                    features=norm_step.features,
                    norm_map=norm_step.norm_map,
                    stats=norm_step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify initial normalizer configuration
    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data
    observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}  # Start with float32
    action = torch.randn(4, dtype=torch.float32)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
@@ -1,11 +1,7 @@
 import torch

-from lerobot.processor.pipeline import (
-    RobotProcessor,
-    TransitionKey,
-    _default_batch_to_transition,
-    _default_transition_to_batch,
-)
+from lerobot.processor import DataProcessorPipeline, TransitionKey
+from lerobot.processor.converters import batch_to_transition, transition_to_batch


 def _dummy_batch():
@@ -24,7 +20,7 @@ def _dummy_batch():

 def test_observation_grouping_roundtrip():
     """Test that observation.* keys are properly grouped and ungrouped."""
-    proc = RobotProcessor([])
+    proc = DataProcessorPipeline([])
     batch_in = _dummy_batch()
     batch_out = proc(batch_in)

@@ -48,19 +44,19 @@ def test_observation_grouping_roundtrip():


 def test_batch_to_transition_observation_grouping():
-    """Test that _default_batch_to_transition correctly groups observation.* keys."""
+    """Test that batch_to_transition correctly groups observation.* keys."""
     batch = {
         "observation.image.top": torch.randn(1, 3, 128, 128),
         "observation.image.left": torch.randn(1, 3, 128, 128),
         "observation.state": [1, 2, 3, 4],
-        "action": "action_data",
+        "action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
         "next.reward": 1.5,
         "next.done": True,
         "next.truncated": False,
         "info": {"episode": 42},
     }

-    transition = _default_batch_to_transition(batch)
+    transition = batch_to_transition(batch)

     # Check observation is a dict with all observation.* keys
     assert isinstance(transition[TransitionKey.OBSERVATION], dict)
@@ -78,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
     assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]

     # Check other fields
-    assert transition[TransitionKey.ACTION] == "action_data"
+    assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
     assert transition[TransitionKey.REWARD] == 1.5
     assert transition[TransitionKey.DONE]
     assert not transition[TransitionKey.TRUNCATED]
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():


 def test_transition_to_batch_observation_flattening():
-    """Test that _default_transition_to_batch correctly flattens observation dict."""
+    """Test that transition_to_batch correctly flattens observation dict."""
     observation_dict = {
         "observation.image.top": torch.randn(1, 3, 128, 128),
         "observation.image.left": torch.randn(1, 3, 128, 128),
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
         TransitionKey.COMPLEMENTARY_DATA: {},
     }

-    batch = _default_transition_to_batch(transition)
+    batch = transition_to_batch(transition)

     # Check that observation.* keys are flattened back to batch
     assert "observation.image.top" in batch
@@ -127,28 +123,28 @@ def test_transition_to_batch_observation_flattening():
 def test_no_observation_keys():
     """Test behavior when there are no observation.* keys."""
     batch = {
-        "action": "action_data",
+        "action": torch.tensor([1.0, 2.0]),
         "next.reward": 2.0,
         "next.done": False,
         "next.truncated": True,
         "info": {"test": "no_obs"},
     }

-    transition = _default_batch_to_transition(batch)
+    transition = batch_to_transition(batch)

     # Observation should be None when no observation.* keys
     assert transition[TransitionKey.OBSERVATION] is None

     # Check other fields
-    assert transition[TransitionKey.ACTION] == "action_data"
+    assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0]))
     assert transition[TransitionKey.REWARD] == 2.0
     assert not transition[TransitionKey.DONE]
     assert transition[TransitionKey.TRUNCATED]
     assert transition[TransitionKey.INFO] == {"test": "no_obs"}

     # Round trip should work
-    reconstructed_batch = _default_transition_to_batch(transition)
-    assert reconstructed_batch["action"] == "action_data"
+    reconstructed_batch = transition_to_batch(transition)
+    assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
     assert reconstructed_batch["next.reward"] == 2.0
     assert not reconstructed_batch["next.done"]
     assert reconstructed_batch["next.truncated"]
@@ -157,13 +153,13 @@ def test_no_observation_keys():

 def test_minimal_batch():
     """Test with minimal batch containing only observation.* and action."""
-    batch = {"observation.state": "minimal_state", "action": "minimal_action"}
+    batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}

-    transition = _default_batch_to_transition(batch)
+    transition = batch_to_transition(batch)

     # Check observation
     assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
-    assert transition[TransitionKey.ACTION] == "minimal_action"
+    assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))

     # Check defaults
     assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -173,9 +169,9 @@ def test_minimal_batch():
     assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

     # Round trip
-    reconstructed_batch = _default_transition_to_batch(transition)
+    reconstructed_batch = transition_to_batch(transition)
     assert reconstructed_batch["observation.state"] == "minimal_state"
-    assert reconstructed_batch["action"] == "minimal_action"
+    assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
     assert reconstructed_batch["next.reward"] == 0.0
     assert not reconstructed_batch["next.done"]
     assert not reconstructed_batch["next.truncated"]
@@ -186,7 +182,7 @@ def test_empty_batch():
     """Test behavior with empty batch."""
     batch = {}

-    transition = _default_batch_to_transition(batch)
+    transition = batch_to_transition(batch)

     # All fields should have defaults
     assert transition[TransitionKey.OBSERVATION] is None
@@ -198,7 +194,7 @@ def test_empty_batch():
     assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}

     # Round trip
-    reconstructed_batch = _default_transition_to_batch(transition)
+    reconstructed_batch = transition_to_batch(transition)
     assert reconstructed_batch["action"] is None
     assert reconstructed_batch["next.reward"] == 0.0
     assert not reconstructed_batch["next.done"]
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
         "info": {"episode_length": 200, "success": True},
     }

-    transition = _default_batch_to_transition(batch)
-    reconstructed_batch = _default_transition_to_batch(transition)
+    transition = batch_to_transition(batch)
+    reconstructed_batch = transition_to_batch(transition)

     # Check that all observation keys are preserved
     original_obs_keys = {k for k in batch if k.startswith("observation.")}
@@ -254,7 +250,7 @@ def test_custom_converter():

     def to_tr(batch):
         # Custom converter that modifies the reward
-        tr = _default_batch_to_transition(batch)
+        tr = batch_to_transition(batch)
         # Double the reward
         reward = tr.get(TransitionKey.REWARD, 0.0)
         new_tr = tr.copy()
@@ -262,10 +258,10 @@ def test_custom_converter():
         return new_tr

     def to_batch(tr):
-        batch = _default_transition_to_batch(tr)
+        batch = transition_to_batch(tr)
         return batch

-    processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
+    processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)

     batch = {
         "observation.state": torch.randn(1, 4),
1184
tests/processor/test_batch_processor.py
Normal file
File diff suppressed because it is too large
362
tests/processor/test_classifier_processor.py
Normal file
@@ -0,0 +1,362 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Reward Classifier processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import OBS_IMAGE, OBS_STATE
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
    DataProcessorPipeline,
    DeviceProcessorStep,
    IdentityProcessorStep,
    NormalizerProcessorStep,
    TransitionKey,
)
from lerobot.processor.converters import create_transition, transition_to_batch


def create_default_config():
    """Create a default Reward Classifier configuration for testing."""
    config = RewardClassifierConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),  # Classifier output
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.IDENTITY,  # No normalization for classifier output
    }
    config.device = "cpu"
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
        OBS_IMAGE: {},  # No normalization for images
        "reward": {},  # No normalization for classifier output
    }


def test_make_classifier_processor_basic():
    """Test basic creation of Classifier processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Check processor names
    assert preprocessor.name == "classifier_preprocessor"
    assert postprocessor.name == "classifier_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 3
    assert isinstance(preprocessor.steps[0], NormalizerProcessorStep)  # For input features
    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)  # For output features
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
    assert isinstance(postprocessor.steps[1], IdentityProcessorStep)


def test_classifier_processor_normalization():
    """Test that Classifier processor correctly normalizes data."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(
        config,
        stats,
    )

    # Create test data
    observation = {
        OBS_STATE: torch.randn(10),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(1)  # Dummy action/reward
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is processed
    assert processed[OBS_STATE].shape == (10,)
    assert processed[OBS_IMAGE].shape == (3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (1,)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_classifier_processor_cuda():
    """Test Classifier processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(
        config,
        stats,
    )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(10),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(1)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

    # Process through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that output is back on CPU
    assert postprocessed.device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_classifier_processor_accelerate_scenario():
    """Test Classifier processor in simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(
        config,
        stats,
    )

    # Simulate Accelerate: data already on GPU
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(10).to(device),
        OBS_IMAGE: torch.randn(3, 224, 224).to(device),
    }
    action = torch.randn(1).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_classifier_processor_multi_gpu():
    """Test Classifier processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Simulate data on different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(10).to(device),
        OBS_IMAGE: torch.randn(3, 224, 224).to(device),
    }
    action = torch.randn(1).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_classifier_processor_without_stats():
    """Test Classifier processor creation without dataset statistics."""
    config = create_default_config()

    preprocessor, postprocessor = make_classifier_processor(config, dataset_stats=None)

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None

    # Process should still work
    observation = {
        OBS_STATE: torch.randn(10),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(1)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed is not None


def test_classifier_processor_save_and_load():
    """Test saving and loading Classifier processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(config, stats)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save preprocessor
        preprocessor.save_pretrained(tmpdir)

        # Load preprocessor
        loaded_preprocessor = DataProcessorPipeline.from_pretrained(
            tmpdir, config_filename="classifier_preprocessor.json"
        )

        # Test that loaded processor works
        observation = {
            OBS_STATE: torch.randn(10),
            OBS_IMAGE: torch.randn(3, 224, 224),
        }
        action = torch.randn(1)
        transition = create_transition(observation, action)
        batch = transition_to_batch(transition)

        processed = loaded_preprocessor(batch)
        assert processed[OBS_STATE].shape == (10,)
        assert processed[OBS_IMAGE].shape == (3, 224, 224)
        assert processed[TransitionKey.ACTION.value].shape == (1,)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_classifier_processor_mixed_precision():
    """Test Classifier processor with mixed precision."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(config, stats)

    # Replace DeviceProcessorStep with one that uses float16
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Create test data
    observation = {
        OBS_STATE: torch.randn(10, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(1, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is converted to float16
    assert processed[OBS_STATE].dtype == torch.float16
    assert processed[OBS_IMAGE].dtype == torch.float16
    assert processed[TransitionKey.ACTION.value].dtype == torch.float16


def test_classifier_processor_batch_data():
    """Test Classifier processor with batched data."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(
        config,
        stats,
    )

    # Test with batched data
    batch_size = 16
    observation = {
        OBS_STATE: torch.randn(batch_size, 10),
        OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
    }
    action = torch.randn(batch_size, 1)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that batch dimension is preserved
    assert processed[OBS_STATE].shape == (batch_size, 10)
    assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (batch_size, 1)


def test_classifier_processor_postprocessor_identity():
    """Test that Classifier postprocessor uses IdentityProcessor correctly."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_classifier_processor(
        config,
        stats,
    )

    # Create test data for postprocessor
    reward = torch.tensor([[0.8], [0.3], [0.9]])  # Batch of rewards/predictions
    transition = create_transition(action=reward)

    _ = transition_to_batch(transition)

    # Process through postprocessor
    processed = postprocessor(reward)

    # IdentityProcessor should leave values unchanged (except device)
    assert torch.allclose(processed.cpu(), reward.cpu())
    assert processed.device.type == "cpu"
292
tests/processor/test_converters.py
Normal file
@@ -0,0 +1,292 @@
import numpy as np
import pytest
import torch

from lerobot.processor import TransitionKey
from lerobot.processor.converters import (
    batch_to_transition,
    create_transition,
    to_tensor,
    transition_to_batch,
)


# Tests for the unified to_tensor function
def test_to_tensor_numpy_arrays():
    """Test to_tensor with various numpy arrays."""
    # Regular numpy array
    arr = np.array([1.0, 2.0, 3.0])
    result = to_tensor(arr)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))

    # Different numpy dtypes should convert to float32 by default
    int_arr = np.array([1, 2, 3], dtype=np.int64)
    result = to_tensor(int_arr)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))

    # uint8 arrays (previously "preserved") should now convert
    uint8_arr = np.array([100, 150, 200], dtype=np.uint8)
    result = to_tensor(uint8_arr)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0]))


def test_to_tensor_numpy_scalars():
    """Test to_tensor with numpy scalars (0-dimensional arrays)."""
    # numpy float32 scalar
    scalar = np.float32(3.14)
    result = to_tensor(scalar)
    assert isinstance(result, torch.Tensor)
    assert result.ndim == 0  # Should be 0-dimensional tensor
    assert result.dtype == torch.float32
    assert result.item() == pytest.approx(3.14)

    # numpy int32 scalar
    int_scalar = np.int32(42)
    result = to_tensor(int_scalar)
    assert isinstance(result, torch.Tensor)
    assert result.ndim == 0
    assert result.dtype == torch.float32
    assert result.item() == pytest.approx(42.0)


def test_to_tensor_python_scalars():
    """Test to_tensor with Python scalars."""
    # Python int
    result = to_tensor(42)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert result.item() == pytest.approx(42.0)

    # Python float
    result = to_tensor(3.14)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert result.item() == pytest.approx(3.14)


def test_to_tensor_sequences():
    """Test to_tensor with lists and tuples."""
    # List
    result = to_tensor([1, 2, 3])
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))

    # Tuple
    result = to_tensor((4.5, 5.5, 6.5))
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5]))


def test_to_tensor_existing_tensors():
    """Test to_tensor with existing PyTorch tensors."""
    # Tensor with same dtype should pass through with potential device change
    tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    result = to_tensor(tensor)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, tensor)

    # Tensor with different dtype should convert
    int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
    result = to_tensor(int_tensor)
    assert isinstance(result, torch.Tensor)
    assert result.dtype == torch.float32
    assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))


def test_to_tensor_dictionaries():
    """Test to_tensor with nested dictionaries."""
    # Simple dictionary
    data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42}
    result = to_tensor(data)
    assert isinstance(result, dict)
    assert isinstance(result["mean"], torch.Tensor)
    assert isinstance(result["std"], torch.Tensor)
    assert isinstance(result["count"], torch.Tensor)
    assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(result["std"], torch.tensor([1.0, 2.0]))
    assert result["count"].item() == pytest.approx(42.0)

    # Nested dictionary
    nested = {
        "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
        "observation": {"mean": np.array([0.5, 0.6]), "count": 10},
    }
    result = to_tensor(nested)
    assert isinstance(result, dict)
    assert isinstance(result["action"], dict)
    assert isinstance(result["observation"], dict)
    assert isinstance(result["action"]["mean"], torch.Tensor)
    assert isinstance(result["observation"]["mean"], torch.Tensor)
    assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6]))


def test_to_tensor_none_filtering():
    """Test that None values are filtered out from dictionaries."""
    data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
    result = to_tensor(data)
    assert "none_value" not in result
    assert "also_none" not in result["nested"]
    assert "valid" in result
    assert "valid" in result["nested"]
    assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0]))


def test_to_tensor_dtype_parameter():
    """Test to_tensor with different dtype parameters."""
    arr = np.array([1, 2, 3])

    # Default dtype (float32)
    result = to_tensor(arr)
    assert result.dtype == torch.float32

    # Explicit float32
    result = to_tensor(arr, dtype=torch.float32)
    assert result.dtype == torch.float32

    # Float64
    result = to_tensor(arr, dtype=torch.float64)
    assert result.dtype == torch.float64

    # Preserve original dtype
    float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    result = to_tensor(float64_arr, dtype=None)
    assert result.dtype == torch.float64


def test_to_tensor_device_parameter():
    """Test to_tensor with device parameter."""
    arr = np.array([1.0, 2.0, 3.0])

    # CPU device (default)
    result = to_tensor(arr, device="cpu")
    assert result.device.type == "cpu"

    # CUDA device (if available)
    if torch.cuda.is_available():
        result = to_tensor(arr, device="cuda")
        assert result.device.type == "cuda"


def test_to_tensor_empty_dict():
    """Test to_tensor with empty dictionary."""
    result = to_tensor({})
    assert isinstance(result, dict)
    assert len(result) == 0


def test_to_tensor_unsupported_type():
    """Test to_tensor with unsupported types raises TypeError."""
    with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
        to_tensor("unsupported_string")

    with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
        to_tensor(object())


def test_batch_to_transition_with_index_fields():
    """Test that batch_to_transition handles index and task_index fields correctly."""

    # Create batch with index and task_index fields
    batch = {
        "observation.state": torch.randn(1, 7),
        "action": torch.randn(1, 4),
        "next.reward": 1.5,
        "next.done": False,
        "task": ["pick_cube"],
        "index": torch.tensor([42], dtype=torch.int64),
        "task_index": torch.tensor([3], dtype=torch.int64),
    }

    transition = batch_to_transition(batch)

    # Check basic transition structure
    assert TransitionKey.OBSERVATION in transition
    assert TransitionKey.ACTION in transition
    assert TransitionKey.COMPLEMENTARY_DATA in transition

    # Check that index and task_index are in complementary_data
    comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
    assert "index" in comp_data
    assert "task_index" in comp_data
    assert "task" in comp_data

    # Verify values
    assert torch.equal(comp_data["index"], batch["index"])
    assert torch.equal(comp_data["task_index"], batch["task_index"])
    assert comp_data["task"] == batch["task"]


def test_transition_to_batch_with_index_fields():
    """Test that transition_to_batch handles index and task_index fields correctly."""

    # Create transition with index and task_index in complementary_data
    transition = create_transition(
        observation={"observation.state": torch.randn(1, 7)},
        action=torch.randn(1, 4),
        reward=1.5,
        done=False,
        complementary_data={
            "task": ["navigate"],
            "index": torch.tensor([100], dtype=torch.int64),
            "task_index": torch.tensor([5], dtype=torch.int64),
        },
    )

    batch = transition_to_batch(transition)

    # Check that index and task_index are in the batch
    assert "index" in batch
    assert "task_index" in batch
    assert "task" in batch

    # Verify values
    assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
    assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
    assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]


def test_batch_to_transition_without_index_fields():
    """Test that conversion works without index and task_index fields."""

    # Batch without index/task_index
    batch = {
        "observation.state": torch.randn(1, 7),
        "action": torch.randn(1, 4),
        "task": ["pick_cube"],
    }

    transition = batch_to_transition(batch)
    comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]

    # Should have task but not index/task_index
    assert "task" in comp_data
    assert "index" not in comp_data
    assert "task_index" not in comp_data


def test_transition_to_batch_without_index_fields():
    """Test that conversion works without index and task_index fields."""

    # Transition without index/task_index
    transition = create_transition(
        observation={"observation.state": torch.randn(1, 7)},
        action=torch.randn(1, 4),
        complementary_data={"task": ["navigate"]},
    )

    batch = transition_to_batch(transition)

    # Should have task but not index/task_index
    assert "task" in batch
    assert "index" not in batch
    assert "task_index" not in batch
1161
tests/processor/test_device_processor.py
Normal file
File diff suppressed because it is too large
398
tests/processor/test_diffusion_processor.py
Normal file
@@ -0,0 +1,398 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Diffusion policy processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DataProcessorPipeline,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch


def create_default_config():
    """Create a default Diffusion configuration for testing."""
    config = DiffusionConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    config.device = "cpu"
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
        OBS_IMAGE: {},  # No normalization for images
        ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
    }


def test_make_diffusion_processor_basic():
    """Test basic creation of Diffusion processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 4
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)


def test_diffusion_processor_with_images():
    """Test Diffusion processor with image observations."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        stats,
    )

    # Create test data with images
    observation = {
        OBS_STATE: torch.randn(7),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is batched
    assert processed[OBS_STATE].shape == (1, 7)
    assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (1, 6)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_diffusion_processor_cuda():
    """Test Diffusion processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        stats,
    )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(7),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

    # Process through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is back on CPU
    assert postprocessed.device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_diffusion_processor_accelerate_scenario():
    """Test Diffusion processor in simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        stats,
    )

    # Simulate Accelerate: data already on GPU
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(1, 7).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 6).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_diffusion_processor_multi_gpu():
    """Test Diffusion processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)

    # Simulate data on different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(1, 7).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 6).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_diffusion_processor_without_stats():
    """Test Diffusion processor creation without dataset statistics."""
    config = create_default_config()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        dataset_stats=None,
    )

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None

    # Process should still work
    observation = {
        OBS_STATE: torch.randn(7),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed is not None


def test_diffusion_processor_save_and_load():
    """Test saving and loading Diffusion processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save preprocessor
        preprocessor.save_pretrained(tmpdir)

        # Load preprocessor
        loaded_preprocessor = DataProcessorPipeline.from_pretrained(
            tmpdir, config_filename="policy_preprocessor.json"
        )

        # Test that loaded processor works
        observation = {
            OBS_STATE: torch.randn(7),
            OBS_IMAGE: torch.randn(3, 224, 224),
        }
        action = torch.randn(6)
        transition = create_transition(observation, action)
        batch = transition_to_batch(transition)

        processed = loaded_preprocessor(batch)
        assert processed[OBS_STATE].shape == (1, 7)
        assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
        assert processed[TransitionKey.ACTION.value].shape == (1, 6)


def test_diffusion_processor_identity_normalization():
    """Test that images with IDENTITY normalization are not normalized."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        stats,
    )

    # Create test data
    image_value = torch.rand(3, 224, 224) * 255  # Large values
    observation = {
        OBS_STATE: torch.randn(7),
        OBS_IMAGE: image_value.clone(),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Image should not be normalized (IDENTITY mode)
    # Just batched
    assert torch.allclose(processed[OBS_IMAGE][0], image_value, rtol=1e-5)


def test_diffusion_processor_batch_consistency():
    """Test Diffusion processor with different batch sizes."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_diffusion_pre_post_processors(
        config,
        stats,
    )

    # Test with different batch sizes
    for batch_size in [1, 8, 32]:
        observation = {
            OBS_STATE: torch.randn(batch_size, 7) if batch_size > 1 else torch.randn(7),
            OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224),
        }
        action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6)
        transition = create_transition(observation, action)

        batch = transition_to_batch(transition)

        processed = preprocessor(batch)

        # Check correct batch size
        expected_batch = batch_size if batch_size > 1 else 1
        assert processed[OBS_STATE].shape[0] == expected_batch
        assert processed[OBS_IMAGE].shape[0] == expected_batch
        assert processed[TransitionKey.ACTION.value].shape[0] == expected_batch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_diffusion_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_diffusion_pre_post_processors(config, stats)

    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            norm_step = step  # Now type checker knows this is NormalizerProcessorStep
            modified_steps.append(
                NormalizerProcessorStep(
                    features=norm_step.features,
                    norm_map=norm_step.norm_map,
                    stats=norm_step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify initial normalizer configuration
    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data with both state and visual observations
    observation = {
        OBS_STATE: torch.randn(7, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    # Check state stats (has normalization)
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
341
tests/processor/test_migration_detection.py
Normal file
@@ -0,0 +1,341 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for processor migration detection functionality.
"""

import json
import tempfile
from pathlib import Path

import pytest

from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError


def test_is_processor_config_valid_configs():
    """Test processor config detection with valid configurations."""
    valid_configs = [
        {"steps": []},  # Empty steps
        {"steps": [{"class": "MyClass"}]},  # Class-based step
        {"steps": [{"registry_name": "my_step"}]},  # Registry-based step
        {"steps": [{"class": "A"}, {"registry_name": "B"}]},  # Mixed
        {"name": "Test", "steps": [{"class": "MyClass"}]},  # With name
    ]

    for i, config in enumerate(valid_configs):
        assert DataProcessorPipeline._is_processor_config(config), (
            f"Valid config {i} should be detected as processor config: {config}"
        )


def test_is_processor_config_invalid_configs():
    """Test processor config detection with invalid configurations."""
    invalid_configs = [
        {},  # No steps field
        {"steps": "not a list"},  # Steps is not a list
        {"steps": [{}]},  # Step without class or registry_name
        {"steps": ["not a dict"]},  # Step is not a dict
        {"steps": [{"other_field": "value"}]},  # Step with wrong fields
        {"other_field": "value"},  # Completely different structure
    ]

    for i, config in enumerate(invalid_configs):
        assert not DataProcessorPipeline._is_processor_config(config), (
            f"Invalid config {i} should not be detected as processor config: {config}"
        )
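
Reviewer note: the valid and invalid lists above pin down a simple structural rule. The sketch below is one predicate that would satisfy both lists. It is only an illustration: the name `looks_like_processor_config` is hypothetical, and nothing here is taken from the actual body of `DataProcessorPipeline._is_processor_config`, which this diff does not show.

```python
from typing import Any


def looks_like_processor_config(config: Any) -> bool:
    """Hypothetical stand-in for the check exercised by the tests (not the library code)."""
    if not isinstance(config, dict):
        return False
    steps = config.get("steps")
    # "steps" must be present and must be a list; an empty list is accepted.
    if not isinstance(steps, list):
        return False
    # Every step must be a dict that names its implementation via "class" or "registry_name".
    return all(
        isinstance(step, dict) and ("class" in step or "registry_name" in step)
        for step in steps
    )
```

Under this rule an extra top-level key such as `"name"` is simply ignored, which matches the "With name" case in the valid list.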


def test_should_suggest_migration_with_processor_config():
    """Test that migration is NOT suggested when processor config exists."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create a valid processor config
        processor_config = {
            "name": "TestProcessor",
            "steps": [
                {
                    "class": "lerobot.processor.normalize.NormalizeStep",
                    "config": {"mean": 0.0, "std": 1.0},
                }
            ],
        }

        with open(tmp_path / "processor.json", "w") as f:
            json.dump(processor_config, f)

        # Should NOT suggest migration (processor config exists)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert not result


def test_should_suggest_migration_with_empty_processor_config():
    """Test that migration is NOT suggested when empty processor config exists."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create an empty processor config
        empty_processor_config = {
            "name": "EmptyProcessor",
            "steps": [],  # Empty steps is valid
        }

        with open(tmp_path / "empty_processor.json", "w") as f:
            json.dump(empty_processor_config, f)

        # Should NOT suggest migration (processor config exists, even if empty)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert not result


def test_should_suggest_migration_with_model_config_only():
    """Test that migration IS suggested when only model config exists."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create a model config (like old LeRobot format)
        model_config = {
            "type": "act",
            "input_features": {"observation.state": {"shape": [7]}},
            "output_features": {"action": {"shape": [7]}},
            "hidden_dim": 256,
            "n_obs_steps": 1,
            "n_action_steps": 1,
        }

        with open(tmp_path / "config.json", "w") as f:
            json.dump(model_config, f)

        # SHOULD suggest migration (model config exists but no processor)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert result


def test_should_suggest_migration_no_json_files():
    """Test that migration is NOT suggested when no JSON files exist."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create some non-JSON files
        with open(tmp_path / "model.safetensors", "w") as f:
            f.write("fake model data")

        with open(tmp_path / "README.md", "w") as f:
            f.write("# Model README")

        # Should NOT suggest migration (no JSON files)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert not result


def test_should_suggest_migration_random_json_files():
    """Test that migration IS suggested when JSON files exist but none are processor configs."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create some random JSON file (not a processor config)
        random_config = {"some_field": "some_value", "another_field": 123}

        with open(tmp_path / "random.json", "w") as f:
            json.dump(random_config, f)

        # SHOULD suggest migration (JSON files exist but none are processor configs)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert result


def test_should_suggest_migration_mixed_configs():
    """Test that migration is NOT suggested when processor config exists alongside other configs."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create both a processor config and a model config
        processor_config = {"name": "TestProcessor", "steps": [{"registry_name": "normalize_step"}]}

        model_config = {"type": "diffusion", "hidden_dim": 512}

        with open(tmp_path / "processor.json", "w") as f:
            json.dump(processor_config, f)

        with open(tmp_path / "config.json", "w") as f:
            json.dump(model_config, f)

        # Should NOT suggest migration (processor config exists)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert not result


def test_should_suggest_migration_invalid_json():
    """Test that invalid JSON is handled gracefully."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create an invalid JSON file
        with open(tmp_path / "invalid.json", "w") as f:
            f.write("{ invalid json")

        # Create a valid non-processor config
        model_config = {"type": "act"}
        with open(tmp_path / "model.json", "w") as f:
            json.dump(model_config, f)

        # SHOULD suggest migration (invalid JSON is ignored, but we have non-processor JSON)
        result = DataProcessorPipeline._should_suggest_migration(tmp_path)
        assert result
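
Reviewer note: taken together, the `_should_suggest_migration` tests describe a directory scan. No JSON files means no suggestion, any valid processor config means no suggestion, and JSON files that are all non-processor (or unparseable) mean a suggestion. A minimal sketch of that decision follows, assuming hypothetical helper names (`_is_processor_like`, `should_suggest_migration_sketch`); it is not the implementation in `lerobot.processor.pipeline`. The tests do not cover a directory containing only unparseable JSON; returning `True` in that case is an assumption of this sketch.

```python
import json
from pathlib import Path
from typing import Any


def _is_processor_like(config: Any) -> bool:
    """Hypothetical structural check: a dict with a list of class/registry steps."""
    if not isinstance(config, dict) or not isinstance(config.get("steps"), list):
        return False
    return all(
        isinstance(step, dict) and ("class" in step or "registry_name" in step)
        for step in config["steps"]
    )


def should_suggest_migration_sketch(model_dir: Path) -> bool:
    """Return True when the directory has JSON files but none is a processor config."""
    json_files = sorted(Path(model_dir).glob("*.json"))
    if not json_files:
        # Nothing that could be an old-style config, so there is nothing to migrate.
        return False
    for json_file in json_files:
        try:
            config = json.loads(json_file.read_text())
        except (json.JSONDecodeError, OSError):
            # Unreadable or malformed files are skipped rather than treated as fatal.
            continue
        if _is_processor_like(config):
            return False
    return True
```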


def test_from_pretrained_multiple_json_files_migration_error():
    """Test that multiple JSON files trigger ProcessorMigrationError."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create multiple non-processor configs
        model_config = {"type": "act", "hidden_dim": 128}
        train_config = {"batch_size": 32, "lr": 0.001}

        with open(tmp_path / "config.json", "w") as f:
            json.dump(model_config, f)

        with open(tmp_path / "train_config.json", "w") as f:
            json.dump(train_config, f)

        # Should raise ProcessorMigrationError
        with pytest.raises(ProcessorMigrationError) as exc_info:
            DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json")

        # Check the error details
        error = exc_info.value
        assert str(tmp_path) in str(error.model_path)
        assert "migrate_policy_normalization.py" in error.migration_command
        assert "not a valid processor configuration" in error.original_error


def test_from_pretrained_no_processor_config_migration_error():
    """Test that missing processor config triggers ProcessorMigrationError."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create a model config but no processor
        model_config = {"type": "diffusion", "hidden_dim": 256}

        with open(tmp_path / "config.json", "w") as f:
            json.dump(model_config, f)

        # Should raise ProcessorMigrationError
        with pytest.raises(ProcessorMigrationError) as exc_info:
            DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json")

        # Check the error details
        error = exc_info.value
        assert str(tmp_path) in str(error.model_path)
        assert "migrate_policy_normalization.py" in error.migration_command
        assert "not a valid processor configuration" in error.original_error


def test_from_pretrained_valid_processor_no_migration_error():
    """Test that valid processor config does NOT trigger migration error."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create a valid processor config
        processor_config = {
            "name": "TestProcessor",
            "steps": [],  # Empty is valid
        }

        with open(tmp_path / "processor.json", "w") as f:
            json.dump(processor_config, f)

        # Should succeed and create pipeline
        pipeline = DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json")
        assert pipeline is not None
        assert pipeline.name == "TestProcessor"
        assert len(pipeline) == 0


def test_from_pretrained_no_json_files_no_migration_error():
    """Test that directories with no JSON files don't trigger migration errors."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create some non-JSON files
        with open(tmp_path / "model.safetensors", "w") as f:
            f.write("fake model data")

        # Should raise FileNotFoundError (config file not found)
        with pytest.raises(FileNotFoundError, match="not found in directory"):
            DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json")


def test_processor_migration_error_creation():
    """Test that ProcessorMigrationError is created correctly."""
    model_path = "/path/to/model"
    migration_command = "python migrate.py --path /path/to/model"
    original_error = "Config not found"

    error = ProcessorMigrationError(model_path, migration_command, original_error)

    assert error.model_path == model_path
    assert error.migration_command == migration_command
    assert error.original_error == original_error
    assert model_path in str(error)
    assert migration_command in str(error)
    assert original_error in str(error)


def test_processor_migration_error_attributes():
    """Test that ProcessorMigrationError has correct attributes."""
    model_path = Path("/test/path")
    migration_command = "python test.py"
    original_error = "Test error"

    error = ProcessorMigrationError(model_path, migration_command, original_error)

    # Test that attributes are accessible
    assert hasattr(error, "model_path")
    assert hasattr(error, "migration_command")
    assert hasattr(error, "original_error")

    # Test that it's still an Exception
    assert isinstance(error, Exception)
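
Reviewer note: the two `ProcessorMigrationError` tests above only require an exception that stores three attributes and includes all of them in its message. A minimal sketch of a class with that shape is shown below. `MigrationNeededError` is a hypothetical name and the message wording is invented for illustration; the real class lives in `lerobot.processor.pipeline` and its body is not part of this diff.

```python
from pathlib import Path
from typing import Union


class MigrationNeededError(Exception):
    """Hypothetical exception carrying the context a caller needs to run a migration."""

    def __init__(self, model_path: Union[str, Path], migration_command: str, original_error: str):
        # Keep the raw values as attributes so callers can inspect them programmatically.
        self.model_path = model_path
        self.migration_command = migration_command
        self.original_error = original_error
        # The message repeats all three values so a plain str(error) is self-explanatory.
        super().__init__(
            f"Model '{model_path}' needs to be migrated to the processor format.\n"
            f"Suggested command: {migration_command}\n"
            f"Original error: {original_error}"
        )
```

Storing the values as attributes as well as in the message is what lets the tests assert on `error.migration_command` directly instead of parsing the string.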


def test_migration_suggestion_raises_error():
    """Test that migration suggestion always raises ProcessorMigrationError."""
    with pytest.raises(ProcessorMigrationError) as exc_info:
        DataProcessorPipeline._suggest_processor_migration("/test/path", "Test error")

    error = exc_info.value
    assert "/test/path" in str(error.model_path)
    assert "Test error" in error.original_error
    assert "migrate_policy_normalization.py" in error.migration_command


def test_migration_error_always_raised_for_invalid_configs():
    """Test that ProcessorMigrationError is always raised for invalid configs."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Create a model config
        model_config = {"type": "test", "param": "value"}
        with open(tmp_path / "config.json", "w") as f:
            json.dump(model_config, f)

        # Should always raise ProcessorMigrationError
        with pytest.raises(ProcessorMigrationError):
            DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json")
File diff suppressed because it is too large
@@ -18,31 +18,16 @@ import numpy as np
 import pytest
 import torch
 
-from lerobot.configs.types import FeatureType
+from lerobot.configs.types import FeatureType, PipelineFeatureType
 from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
-from lerobot.processor import VanillaObservationProcessor
-from lerobot.processor.pipeline import TransitionKey
+from lerobot.processor import TransitionKey, VanillaObservationProcessorStep
+from lerobot.processor.converters import create_transition
 from tests.conftest import assert_contract_is_typed
 
 
-def create_transition(
-    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
-):
-    """Helper to create an EnvTransition dictionary."""
-    return {
-        TransitionKey.OBSERVATION: observation,
-        TransitionKey.ACTION: action,
-        TransitionKey.REWARD: reward,
-        TransitionKey.DONE: done,
-        TransitionKey.TRUNCATED: truncated,
-        TransitionKey.INFO: info,
-        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
-    }
-
-
 def test_process_single_image():
     """Test processing a single image."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Create a mock image (H, W, C) format, uint8
     image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
@@ -68,7 +53,7 @@ def test_process_single_image():
 
 def test_process_image_dict():
     """Test processing multiple images in a dictionary."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Create mock images
     image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
@@ -91,7 +76,7 @@ def test_process_image_dict():
 
 def test_process_batched_image():
     """Test processing already batched images."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Create a batched image (B, H, W, C)
     image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
@@ -108,7 +93,7 @@ def test_process_batched_image():
 
 def test_invalid_image_format():
     """Test error handling for invalid image formats."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Test wrong channel order (channels first)
     image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
@@ -121,7 +106,7 @@ def test_invalid_image_format():
 
 def test_invalid_image_dtype():
     """Test error handling for invalid image dtype."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Test wrong dtype
     image = np.random.rand(64, 64, 3).astype(np.float32)
@@ -134,7 +119,7 @@ def test_invalid_image_dtype():
 
 def test_no_pixels_in_observation():
     """Test processor when no pixels are in observation."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     observation = {"other_data": np.array([1, 2, 3])}
     transition = create_transition(observation=observation)
@@ -149,9 +134,9 @@ def test_no_pixels_in_observation():
 
 def test_none_observation():
     """Test processor with None observation."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
-    transition = create_transition()
+    transition = create_transition(observation={})
     result = processor(transition)
 
     assert result == transition
@@ -159,7 +144,7 @@ def test_none_observation():
 
 def test_serialization_methods():
     """Test serialization methods."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     # Test get_config
     config = processor.get_config()
@@ -178,7 +163,7 @@ def test_serialization_methods():
 
 def test_process_environment_state():
     """Test processing environment_state."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
     observation = {"environment_state": env_state}
@@ -199,7 +184,7 @@ def test_process_environment_state():
 
 def test_process_agent_pos():
     """Test processing agent_pos."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
     observation = {"agent_pos": agent_pos}
@@ -220,7 +205,7 @@ def test_process_agent_pos():
 
 def test_process_batched_states():
     """Test processing already batched states."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
     agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
@@ -238,7 +223,7 @@ def test_process_batched_states():
 
 def test_process_both_states():
     """Test processing both environment_state and agent_pos."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     env_state = np.array([1.0, 2.0], dtype=np.float32)
     agent_pos = np.array([0.5, -0.5], dtype=np.float32)
@@ -263,7 +248,7 @@ def test_process_both_states():
 
 def test_no_states_in_observation():
     """Test processor when no states are in observation."""
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
 
     observation = {"other_data": np.array([1, 2, 3])}
     transition = create_transition(observation=observation)
@@ -277,7 +262,7 @@ def test_no_states_in_observation():
|
||||
|
||||
def test_complete_observation_processing():
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -314,7 +299,7 @@ def test_complete_observation_processing():
|
||||
|
||||
def test_image_only_processing():
|
||||
"""Test processing observation with only images."""
|
||||
processor = VanillaObservationProcessor()
|
||||
processor = VanillaObservationProcessorStep()
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
@@ -329,7 +314,7 @@ def test_image_only_processing():
|
||||
|
||||
def test_state_only_processing():
|
||||
"""Test processing observation with only states."""
|
||||
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
@@ -344,7 +329,7 @@ def test_state_only_processing():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processing empty observation."""
|
||||
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -360,7 +345,7 @@ def test_equivalent_to_original_function():
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
@@ -387,7 +372,7 @@ def test_equivalent_with_image_dict():
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
-    processor = VanillaObservationProcessor()
+    processor = VanillaObservationProcessorStep()
|
||||
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
@@ -410,77 +395,133 @@ def test_equivalent_with_image_dict():
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
|
||||
-def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
-    processor = VanillaObservationProcessor()
+def test_image_processor_features_pixels_to_image(policy_feature_factory):
+    processor = VanillaObservationProcessorStep()
     features = {
-        "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
-        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        PipelineFeatureType.OBSERVATION: {
+            "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
+            "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        },
     }
-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())

-    assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
-    assert "pixels" not in out
-    assert out["keep"] == features["keep"]
+    assert (
+        OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
+        == features[PipelineFeatureType.OBSERVATION]["pixels"]
+    )
+    assert "pixels" not in out[PipelineFeatureType.OBSERVATION]
+    assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
     assert_contract_is_typed(out)
|
||||
|
||||
|
||||
-def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
-    processor = VanillaObservationProcessor()
+def test_image_processor_features_observation_pixels_to_image(policy_feature_factory):
+    processor = VanillaObservationProcessorStep()
     features = {
-        "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
-        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        PipelineFeatureType.OBSERVATION: {
+            "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
+            "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        },
     }
-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())

-    assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
-    assert "observation.pixels" not in out
-    assert out["keep"] == features["keep"]
+    assert (
+        OBS_IMAGE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE]
+        == features[PipelineFeatureType.OBSERVATION]["observation.pixels"]
+    )
+    assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION]
+    assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
     assert_contract_is_typed(out)
|
||||
|
||||
|
||||
-def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
-    processor = VanillaObservationProcessor()
+def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory):
+    processor = VanillaObservationProcessorStep()
     features = {
-        "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
-        "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
-        "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
-        "keep": policy_feature_factory(FeatureType.ENV, (7,)),
+        PipelineFeatureType.OBSERVATION: {
+            "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
+            "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
+            "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
+            "keep": policy_feature_factory(FeatureType.ENV, (7,)),
+        },
     }
-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())

-    assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
-    assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
-    assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
-    assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
-    assert out["keep"] == features["keep"]
+    assert (
+        f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"]
+        == features[PipelineFeatureType.OBSERVATION]["pixels.front"]
+    )
+    assert (
+        f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"]
+        == features[PipelineFeatureType.OBSERVATION]["pixels.wrist"]
+    )
+    assert (
+        f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"]
+        == features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"]
+    )
+    assert (
+        "pixels.front" not in out[PipelineFeatureType.OBSERVATION]
+        and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION]
+        and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION]
+    )
+    assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
     assert_contract_is_typed(out)
|
||||
|
||||
|
||||
-def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
-    processor = VanillaObservationProcessor()
+def test_state_processor_features_environment_and_agent_pos(policy_feature_factory):
+    processor = VanillaObservationProcessorStep()
     features = {
-        "environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
-        "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
-        "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        PipelineFeatureType.OBSERVATION: {
+            "environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
+            "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
+            "keep": policy_feature_factory(FeatureType.ENV, (1,)),
+        },
     }
-    out = processor.feature_contract(features.copy())
+    out = processor.transform_features(features.copy())

-    assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
-    assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
-    assert "environment_state" not in out and "agent_pos" not in out
-    assert out["keep"] == features["keep"]
+    assert (
+        OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
+        == features[PipelineFeatureType.OBSERVATION]["environment_state"]
+    )
+    assert (
+        OBS_STATE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
+        == features[PipelineFeatureType.OBSERVATION]["agent_pos"]
+    )
+    assert (
+        "environment_state" not in out[PipelineFeatureType.OBSERVATION]
+        and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
+    )
+    assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"]
     assert_contract_is_typed(out)
|
||||
|
||||
|
||||
-def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
-    proc = VanillaObservationProcessor()
+def test_state_processor_features_prefixed_inputs(policy_feature_factory):
+    proc = VanillaObservationProcessorStep()
     features = {
-        "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
-        "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
+        PipelineFeatureType.OBSERVATION: {
+            "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
+            "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
+        },
     }
-    out = proc.feature_contract(features.copy())
+    out = proc.transform_features(features.copy())

-    assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
-    assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
-    assert "environment_state" not in out and "agent_pos" not in out
+    assert (
+        OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE]
+        == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"]
+    )
+    assert (
+        OBS_STATE in out[PipelineFeatureType.OBSERVATION]
+        and out[PipelineFeatureType.OBSERVATION][OBS_STATE]
+        == features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"]
+    )
+    assert (
+        "environment_state" not in out[PipelineFeatureType.OBSERVATION]
+        and "agent_pos" not in out[PipelineFeatureType.OBSERVATION]
+    )
     assert_contract_is_typed(out)
|
||||
|
||||
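The feature tests in this file all exercise one renaming rule: raw environment keys (`pixels`, `pixels.<camera>`, `environment_state`, `agent_pos`, each optionally prefixed with `observation.`) are mapped to the LeRobot constant names, and unrelated keys such as `keep` pass through. A minimal sketch of that rule follows. The helper names `rename_observation_key` and `rename_features` are hypothetical, the string values of the constants are assumptions (the real ones are imported from `lerobot.constants`), and this is not the `VanillaObservationProcessorStep` implementation.

```python
# Assumed string values; the tests import the real constants from lerobot.constants.
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_STATE = "observation.state"
OBS_ENV_STATE = "observation.environment_state"

PREFIX = "observation."


def rename_observation_key(key: str) -> str:
    """Hypothetical helper: map a raw env key to the name the tests expect."""
    # Treat "observation.pixels" the same as "pixels", and so on.
    bare = key[len(PREFIX):] if key.startswith(PREFIX) else key
    if bare == "pixels":
        return OBS_IMAGE
    if bare.startswith("pixels."):
        camera = bare[len("pixels."):]
        return f"{OBS_IMAGES}.{camera}"
    if bare == "environment_state":
        return OBS_ENV_STATE
    if bare == "agent_pos":
        return OBS_STATE
    # Anything else (for example "keep") is left untouched.
    return key


def rename_features(features: dict) -> dict:
    """Apply the renaming to every key of one feature group."""
    return {rename_observation_key(k): v for k, v in features.items()}


renamed = rename_features({"pixels.front": "cam", "agent_pos": "pos", "keep": "x"})
```

In the updated tests the same mapping is checked one level down, inside `features[PipelineFeatureType.OBSERVATION]`, rather than on a flat dictionary.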
424
tests/processor/test_pi0_processor.py
Normal file
@@ -0,0 +1,424 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for PI0 policy processor."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, transition_to_batch
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
"""Mock tokenizer processor step for testing."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Accept any arguments to mimic the real TokenizerProcessorStep interface
|
||||
pass
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Pass through transition unchanged
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
# Pass through features unchanged
|
||||
return features
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default PI0 configuration for testing."""
|
||||
config = PI0Config()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
config.tokenizer_max_length = 128
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_pi0_processor_basic():
|
||||
"""Test basic creation of PI0 processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
|
||||
# Step 3 is the patched TokenizerProcessorStep (MockTokenizerProcessorStep in this test), so its type is not asserted
|
||||
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
|
||||
|
||||
|
||||
def test_pi0_newline_processor_single_task():
|
||||
"""Test Pi0NewLineProcessor with single task string."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with task that doesn't have newline
|
||||
transition = create_transition(complementary_data={"task": "test task"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
# Test with task that already has newline
|
||||
transition = create_transition(complementary_data={"task": "test task\n"})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
|
||||
|
||||
|
||||
def test_pi0_newline_processor_list_of_tasks():
|
||||
"""Test Pi0NewLineProcessor with list of task strings."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with list of tasks
|
||||
tasks = ["task1", "task2\n", "task3"]
|
||||
transition = create_transition(complementary_data={"task": tasks})
|
||||
result = processor(transition)
|
||||
expected = ["task1\n", "task2\n", "task3\n"]
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
|
||||
|
||||
|
||||
def test_pi0_newline_processor_empty_transition():
|
||||
"""Test Pi0NewLineProcessor with empty transition."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test with no complementary_data
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with complementary_data but no task
|
||||
transition = create_transition(complementary_data={"other": "data"})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
# Test with None task
|
||||
transition = create_transition(complementary_data={"task": None})
|
||||
result = processor(transition)
|
||||
assert result == transition
|
||||
|
||||
|
||||
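The three `Pi0NewLineProcessor` tests above pin down a small contract: a task string gains a trailing newline only if it lacks one, a list of tasks is handled element by element, and a missing or `None` task leaves the data unchanged. A rough sketch of that contract on plain dictionaries is shown here. Both function names are hypothetical and the real step operates on an `EnvTransition`, so treat this as an illustration of the tested behaviour, not the library code.

```python
from typing import Any, Optional


def ensure_trailing_newline(task: Any) -> Any:
    """Hypothetical helper: append a newline to a task string if it has none."""
    if task is None:
        return None
    if isinstance(task, str):
        return task if task.endswith("\n") else task + "\n"
    if isinstance(task, list):
        # Batched case: fix each task independently.
        return [ensure_trailing_newline(t) for t in task]
    # Unknown types are returned as-is in this sketch.
    return task


def add_newline_to_task(complementary_data: Optional[dict]) -> Optional[dict]:
    """Return a copy with a normalised "task", or the input itself if there is no task."""
    if not complementary_data or complementary_data.get("task") is None:
        return complementary_data
    updated = dict(complementary_data)  # avoid mutating the caller's dict
    updated["task"] = ensure_trailing_newline(updated["task"])
    return updated


single = add_newline_to_task({"task": "test task"})
batched = add_newline_to_task({"task": ["task1", "task2\n", "task3"]})
```

Returning the input object unchanged when there is nothing to do mirrors the `assert result == transition` checks in the empty-transition test.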
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_cuda():
|
||||
"""Test PI0 processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Create CPU data
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10),
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[OBS_STATE].device.type == "cuda"
|
||||
assert processed[OBS_IMAGE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION.value].device.type == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_accelerate_scenario():
|
||||
"""Test PI0 processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Simulate Accelerate: data already on GPU and batched
|
||||
device = torch.device("cuda:0")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 10).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[OBS_STATE].device == device
|
||||
assert processed[OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_pi0_processor_multi_gpu():
|
||||
"""Test PI0 processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Mock the tokenizer processor to act as pass-through
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_config(self):
|
||||
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(1, 10).to(device),
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[OBS_STATE].device == device
|
||||
assert processed[OBS_IMAGE].device == device
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
def test_pi0_processor_without_stats():
|
||||
"""Test PI0 processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
# Mock the tokenizer processor
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config,
|
||||
dataset_stats=None,
|
||||
)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
|
||||
def test_pi0_newline_processor_state_dict():
|
||||
"""Test Pi0NewLineProcessor state dict methods."""
|
||||
processor = Pi0NewLineProcessor()
|
||||
|
||||
# Test state_dict (should be empty)
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
# Test load_state_dict (should do nothing)
|
||||
processor.load_state_dict({})
|
||||
|
||||
# Test reset (should do nothing)
|
||||
processor.reset()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert config == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_pi0_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
config.device = "cuda"
|
||||
|
||||
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
|
||||
preprocessor, _ = make_pi0_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
norm_step = step # Now type checker knows this is NormalizerProcessorStep
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=norm_step.features,
|
||||
norm_map=norm_step.norm_map,
|
||||
stats=norm_step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
|
||||
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data with both state and visual observations
|
||||
observation = {
|
||||
OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[OBS_STATE].dtype == torch.bfloat16
|
||||
assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion
|
||||
assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
# Check state stats (has normalization)
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
# OBS_IMAGE uses IDENTITY normalization, so no stats to check
|
||||
File diff suppressed because it is too large
259
tests/processor/test_pipeline_from_pretrained_helpers.py
Normal file
@@ -0,0 +1,259 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for DataProcessorPipeline.from_pretrained helper methods.
|
||||
|
||||
These tests focus on the individual private methods that were extracted from
|
||||
the main from_pretrained method to improve modularity and testability.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError
|
||||
|
||||
# Simplified Config Loading Tests
|
||||
|
||||
|
||||
def test_load_config_directory():
|
||||
"""Test loading config from directory."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Create a config file
|
||||
config_file = tmp_path / "processor.json"
|
||||
test_config = {"name": "TestProcessor", "steps": []}
|
||||
config_file.write_text(json.dumps(test_config))
|
||||
|
||||
# Load from directory
|
||||
loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {})
|
||||
|
||||
assert loaded_config == test_config
|
||||
assert base_path == tmp_path
|
||||
|
||||
|
||||
def test_load_config_single_file():
|
||||
"""Test loading config from a single file path."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Create a config file
|
||||
config_file = tmp_path / "processor.json"
|
||||
test_config = {"name": "TestProcessor", "steps": []}
|
||||
config_file.write_text(json.dumps(test_config))
|
||||
|
||||
# Load using file path directly
|
||||
loaded_config, base_path = DataProcessorPipeline._load_config(
|
||||
str(config_file), "any_filename_ignored", {}
|
||||
)
|
||||
|
||||
assert loaded_config == test_config
|
||||
assert base_path == tmp_path
|
||||
|
||||
|
||||
def test_load_config_directory_file_not_found():
|
||||
"""Test directory loading when config file doesn't exist."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Directory exists but no processor.json
|
||||
with pytest.raises(FileNotFoundError, match="not found in directory"):
|
||||
DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {})
|
||||
|
||||
|
||||
def test_load_config_directory_with_migration_detection():
|
||||
"""Test that missing config triggers migration detection."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Create old-style config to trigger migration
|
||||
(tmp_path / "config.json").write_text(json.dumps({"type": "act"}))
|
||||
|
||||
# Try to load processor.json (doesn't exist), should trigger migration
|
||||
with pytest.raises(ProcessorMigrationError):
|
||||
DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {})
|
||||
|
||||
|
||||
def test_load_config_nonexistent_path_tries_hub():
|
||||
"""Test that nonexistent paths try Hub (simplified logic)."""
|
||||
# This path doesn't exist locally, should try Hub
|
||||
with pytest.raises(FileNotFoundError, match="on the HuggingFace Hub"):
|
||||
DataProcessorPipeline._load_config("nonexistent/path", "processor.json", {})
|
||||
|
||||
|
||||
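The config-loading tests above distinguish three cases: a directory (the named config file must exist inside it), a direct file path (the filename argument is ignored and the parent directory becomes the base path), and anything else (treated as a Hub repository id). The sketch below illustrates only that dispatch. `resolve_config_source` is a hypothetical name, the returned tuple shape is an assumption, and the migration detection and the actual Hub download performed by `DataProcessorPipeline._load_config` are deliberately left out.

```python
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def resolve_config_source(
    source: str, config_filename: str
) -> Tuple[str, Optional[Path], Optional[Path]]:
    """Hypothetical sketch: return (kind, config_path, base_path) for a source string."""
    path = Path(source)
    if path.is_dir():
        candidate = path / config_filename
        if not candidate.is_file():
            raise FileNotFoundError(f"{config_filename} not found in directory {path}")
        return "local", candidate, path
    if path.is_file():
        # A direct file path wins; config_filename is ignored in this case.
        return "local", path, path.parent
    # Not a local path: the caller would try the Hub next (not done here).
    return "hub", None, None


with tempfile.TemporaryDirectory() as tmp_dir:
    cfg = Path(tmp_dir) / "processor.json"
    cfg.write_text("{}")
    from_dir = resolve_config_source(tmp_dir, "processor.json")
    from_file = resolve_config_source(str(cfg), "any_filename_ignored")

from_hub = resolve_config_source("user/repo", "processor.json")
```

Both local cases resolve to the same base path, which is what `test_load_config_directory` and `test_load_config_single_file` assert with `base_path == tmp_path`.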
# Config Validation Tests
|
||||
|
||||
|
||||
def test_validate_loaded_config_valid_config():
|
||||
"""Test validation with valid processor config."""
|
||||
valid_config = {"name": "TestProcessor", "steps": []}
|
||||
|
||||
# Should not raise any exception
|
||||
DataProcessorPipeline._validate_loaded_config("any-path", valid_config, "processor.json")
|
||||
|
||||
|
||||
def test_validate_loaded_config_invalid_config():
|
||||
"""Test validation with invalid processor config."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Create non-processor config to trigger migration
|
||||
(tmp_path / "config.json").write_text(json.dumps({"type": "act"}))
|
||||
|
||||
invalid_config = {"type": "act", "hidden_dim": 256}
|
||||
|
||||
with pytest.raises(ProcessorMigrationError):
|
||||
DataProcessorPipeline._validate_loaded_config(str(tmp_path), invalid_config, "config.json")
|
||||
|
||||
|
||||
def test_validate_loaded_config_invalid_config_no_migration():
|
||||
"""Test validation with invalid config when no migration is detected."""
|
||||
# Non-directory path (Hub repo) - no migration detection
|
||||
invalid_config = {"type": "act", "hidden_dim": 256}
|
||||
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline._validate_loaded_config("user/repo", invalid_config, "config.json")
|
||||
|
||||
|
||||
# Step Class Resolution Tests
|
||||
|
||||
|
||||
def test_resolve_step_class_registry_name():
|
||||
"""Test resolution using registry name."""
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Register a test step
|
||||
@ProcessorStepRegistry.register("test_step")
|
||||
class TestStep(ProcessorStep):
|
||||
def __call__(self, transition):
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
try:
|
||||
step_entry = {"registry_name": "test_step"}
|
||||
step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry)
|
||||
|
||||
assert step_class is TestStep
|
||||
assert step_key == "test_step"
|
||||
finally:
|
||||
ProcessorStepRegistry.unregister("test_step")
|
||||
|
||||
|
||||
def test_resolve_step_class_registry_name_not_found():
|
||||
"""Test resolution with non-existent registry name."""
|
||||
step_entry = {"registry_name": "nonexistent_step"}
|
||||
|
||||
with pytest.raises(ImportError, match="Failed to load processor step from registry"):
|
||||
DataProcessorPipeline._resolve_step_class(step_entry)
|
||||
|
||||
|
||||
def test_resolve_step_class_import_path():
|
||||
"""Test resolution using full import path."""
|
||||
# Use a valid existing class (this should work)
|
||||
step_entry = {"class": "lerobot.processor.pipeline.ProcessorStep"}
|
||||
|
||||
# This should succeed - ProcessorStep can be imported, just not instantiated
|
||||
step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry)
|
||||
|
||||
from lerobot.processor.pipeline import ProcessorStep
|
||||
|
||||
assert step_class is ProcessorStep
|
||||
assert step_key == "ProcessorStep"
|
||||
|
||||
|
||||
def test_resolve_step_class_invalid_import_path():
|
||||
"""Test resolution with invalid import path."""
|
||||
step_entry = {"class": "nonexistent.module.ClassName"}
|
||||
|
||||
with pytest.raises(ImportError, match="Failed to load processor step"):
|
||||
DataProcessorPipeline._resolve_step_class(step_entry)
|
||||
|
||||
|
||||
# Override Validation Tests
|
||||
|
||||
|
||||
def test_validate_overrides_used_all_used():
|
||||
"""Test validation when all overrides are used."""
|
||||
# Empty set means all overrides were used
|
||||
remaining_overrides = set()
|
||||
config = {"steps": [{"class": "SomeStep"}]}
|
||||
|
||||
# Should not raise
|
||||
DataProcessorPipeline._validate_overrides_used(remaining_overrides, config)
|
||||
|
||||
|
||||
def test_validate_overrides_used_some_unused():
|
||||
"""Test validation when some overrides are unused."""
|
||||
remaining_overrides = {"NonExistentStep", "AnotherMissingStep"}
|
||||
config = {
|
||||
"steps": [
|
||||
{"registry_name": "normalize_step"},
|
||||
{"class": "some.module.TransformStep"},
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(KeyError, match="Override keys.*do not match any step"):
|
||||
DataProcessorPipeline._validate_overrides_used(remaining_overrides, config)
|
||||
|
||||
|
||||
def test_validate_overrides_used_helpful_error_message():
|
||||
"""Test that error message includes available step keys."""
|
||||
remaining_overrides = {"WrongStep"}
|
||||
config = {
|
||||
"steps": [
|
||||
{"registry_name": "correct_step"},
|
||||
{"class": "module.path.CorrectClass"},
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
DataProcessorPipeline._validate_overrides_used(remaining_overrides, config)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Available step keys" in error_msg
|
||||
assert "correct_step" in error_msg
|
||||
assert "CorrectClass" in error_msg
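The override-validation tests above can be illustrated with a minimal, standalone sketch. It assumes a step's key is its `registry_name` when present and otherwise the final component of its `class` import path, which is what the `_resolve_step_class` tests suggest. The names `step_key` and `validate_overrides_used_sketch` are hypothetical and are not part of lerobot.

```python
def step_key(step_entry: dict) -> str:
    """Derive the override key for a step entry (hypothetical helper)."""
    if "registry_name" in step_entry:
        return step_entry["registry_name"]
    # For a dotted import path, keep only the trailing class name.
    return step_entry["class"].rsplit(".", 1)[-1]


def validate_overrides_used_sketch(remaining_overrides: set, config: dict) -> None:
    """Raise KeyError if any override key matched no step (illustrative only)."""
    if not remaining_overrides:
        return  # every override was consumed by some step
    available = [step_key(step) for step in config.get("steps", [])]
    raise KeyError(
        f"Override keys {sorted(remaining_overrides)} do not match any step. "
        f"Available step keys: {available}"
    )


config = {
    "steps": [
        {"registry_name": "correct_step"},
        {"class": "module.path.CorrectClass"},
    ]
}

# An empty set means nothing was left unused, so this call returns quietly.
validate_overrides_used_sketch(set(), config)

# A leftover key produces an error that lists the keys that would have matched.
try:
    validate_overrides_used_sketch({"WrongStep"}, config)
    error_msg = ""
except KeyError as exc:
    error_msg = str(exc)
```

Listing the available keys in the error is the design point the last test checks: a mistyped override name can be corrected from the message alone.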
|
||||
|
||||
|
||||
# Integration Tests for Simplified Logic
|
||||
|
||||
|
||||
def test_simplified_three_way_loading():
|
||||
"""Test that the simplified 3-way loading logic works correctly."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir)
|
||||
|
||||
# Test 1: Directory loading
|
||||
config_file = tmp_path / "processor.json"
|
||||
test_config = {"name": "DirectoryTest", "steps": []}
|
||||
config_file.write_text(json.dumps(test_config))
|
||||
|
||||
loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {})
|
||||
assert loaded_config["name"] == "DirectoryTest"
|
||||
assert base_path == tmp_path
|
||||
|
||||
# Test 2: Single file loading
|
||||
loaded_config, base_path = DataProcessorPipeline._load_config(
|
||||
str(config_file), "ignored_filename", {}
|
||||
)
|
||||
assert loaded_config["name"] == "DirectoryTest"
|
||||
assert base_path == tmp_path
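The three loading cases these tests cover (directory, single file, anything else treated as a Hub repo id) can be sketched as below. This is an assumption-laden illustration, not the `DataProcessorPipeline._load_config` implementation: `load_config_sketch` is a hypothetical name, it takes no Hub keyword arguments, and the Hub branch makes no network call.

```python
import json
import tempfile
from pathlib import Path


def load_config_sketch(source: str, config_filename: str) -> tuple[dict, Path]:
    """Hypothetical stand-in for the loader under test; returns (config, base_path)."""
    path = Path(source)
    if path.is_dir():
        # Case 1: a directory. The named config file must live inside it.
        config_path = path / config_filename
        if not config_path.exists():
            raise FileNotFoundError(f"{config_filename} not found in directory {path}")
        return json.loads(config_path.read_text()), path
    if path.is_file():
        # Case 2: a single file. config_filename is ignored; base path is the parent.
        return json.loads(path.read_text()), path.parent
    # Case 3: anything else would be treated as a Hub repo id. This sketch
    # does not contact the Hub and simply reports the lookup as failed.
    raise FileNotFoundError(f"{source} not found locally or on the HuggingFace Hub")


with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_path = Path(tmp_dir)
    config_file = tmp_path / "processor.json"
    config_file.write_text(json.dumps({"name": "DirectoryTest", "steps": []}))

    from_dir, dir_base = load_config_sketch(str(tmp_path), "processor.json")
    from_file, file_base = load_config_sketch(str(config_file), "ignored_filename")

# A path that does not exist locally falls through to the Hub branch.
try:
    load_config_sketch("nonexistent/path", "processor.json")
    hub_message = ""
except FileNotFoundError as exc:
    hub_message = str(exc)
```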
|
||||
525
tests/processor/test_policy_robot_bridge.py
Normal file
@@ -0,0 +1,525 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
PolicyActionToRobotActionProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import identity_transition
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def test_robot_to_policy_basic_action_conversion():
|
||||
"""Test basic robot action to policy action conversion."""
|
||||
motor_names = ["joint1", "joint2", "joint3"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {
|
||||
"joint1.pos": 1.0,
|
||||
"joint2.pos": 2.0,
|
||||
"joint3.pos": 3.0,
|
||||
}
|
||||
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
assert isinstance(policy_action, torch.Tensor)
|
||||
assert policy_action.shape == (3,)
|
||||
torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_robot_to_policy_action_conversion_preserves_order():
|
||||
"""Test that motor names order is preserved in conversion."""
|
||||
motor_names = ["gripper", "arm", "wrist"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {
|
||||
"arm.pos": 10.0,
|
||||
"gripper.pos": 5.0,
|
||||
"wrist.pos": 15.0,
|
||||
}
|
||||
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
expected = torch.tensor([5.0, 10.0, 15.0])
|
||||
torch.testing.assert_close(policy_action, expected)
|
||||
|
||||
|
||||
def test_robot_to_policy_action_conversion_with_floats_and_tensors():
|
||||
"""Test conversion with mixed float and tensor values."""
|
||||
motor_names = ["joint1", "joint2"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {
|
||||
"joint1.pos": torch.tensor(1.5),
|
||||
"joint2.pos": 2.5, # Regular float
|
||||
}
|
||||
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
assert isinstance(policy_action, torch.Tensor)
|
||||
torch.testing.assert_close(policy_action, torch.tensor([1.5, 2.5]))
|
||||
|
||||
|
||||
def test_robot_to_policy_action_length_mismatch_error():
    """Test error when robot action length doesn't match motor names."""
    motor_names = ["joint1", "joint2", "joint3"]
    processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)

    # Too few actions
    robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0}

    with pytest.raises(ValueError, match="Action must have 3 elements, got 2"):
        processor.action(robot_action)

    # Too many actions
    robot_action = {
        "joint1.pos": 1.0,
        "joint2.pos": 2.0,
        "joint3.pos": 3.0,
        "extra.pos": 4.0,
    }

    with pytest.raises(ValueError, match="Action must have 3 elements, got 4"):
        processor.action(robot_action)
|
||||
|
||||
|
||||
def test_robot_to_policy_missing_motor_key_error():
|
||||
"""Test error when robot action is missing expected motor keys."""
|
||||
motor_names = ["joint1", "joint2"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {
|
||||
"joint1.pos": 1.0,
|
||||
"wrong_key.pos": 2.0,
|
||||
}
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
processor.action(robot_action)
|
||||
|
||||
|
||||
def test_robot_to_policy_transform_features():
|
||||
"""Test feature transformation for robot to policy action processor."""
|
||||
motor_names = ["joint1", "joint2", "joint3"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"joint1.pos": {"type": FeatureType.ACTION, "shape": (1,)},
|
||||
"joint2.pos": {"type": FeatureType.ACTION, "shape": (1,)},
|
||||
"joint3.pos": {"type": FeatureType.ACTION, "shape": (1,)},
|
||||
"other_data": {"type": FeatureType.ENV, "shape": (1,)},
|
||||
}
|
||||
}
|
||||
|
||||
transformed = processor.transform_features(features)
|
||||
|
||||
assert "action" in transformed[PipelineFeatureType.ACTION]
|
||||
action_feature = transformed[PipelineFeatureType.ACTION]["action"]
|
||||
assert action_feature.type == FeatureType.ACTION
|
||||
assert action_feature.shape == (3,)
|
||||
|
||||
assert "joint1.pos" in transformed[PipelineFeatureType.ACTION]
|
||||
assert "joint2.pos" in transformed[PipelineFeatureType.ACTION]
|
||||
assert "joint3.pos" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
assert "other_data" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
|
||||
def test_robot_to_policy_get_config():
|
||||
"""Test configuration serialization."""
|
||||
motor_names = ["motor1", "motor2"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"motor_names": motor_names}
|
||||
|
||||
|
||||
def test_robot_to_policy_state_dict():
|
||||
"""Test state dict operations."""
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=["joint1"])
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
processor.load_state_dict({})
|
||||
|
||||
|
||||
def test_robot_to_policy_single_motor():
|
||||
"""Test with single motor."""
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=["single_joint"])
|
||||
|
||||
robot_action = {"single_joint.pos": 42.0}
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
assert policy_action.shape == (1,)
|
||||
torch.testing.assert_close(policy_action, torch.tensor([42.0]))
|
||||
|
||||
|
||||
def test_policy_to_robot_basic_action_conversion():
|
||||
"""Test basic policy action to robot action conversion."""
|
||||
motor_names = ["joint1", "joint2", "joint3"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
policy_action = torch.tensor([1.0, 2.0, 3.0])
|
||||
robot_action = processor.action(policy_action)
|
||||
|
||||
assert isinstance(robot_action, dict)
|
||||
assert len(robot_action) == 3
|
||||
|
||||
expected = {
|
||||
"joint1.pos": 1.0,
|
||||
"joint2.pos": 2.0,
|
||||
"joint3.pos": 3.0,
|
||||
}
|
||||
|
||||
for key, expected_value in expected.items():
|
||||
assert key in robot_action
|
||||
actual_value = robot_action[key]
|
||||
if isinstance(actual_value, torch.Tensor):
|
||||
actual_value = actual_value.item()
|
||||
assert actual_value == pytest.approx(expected_value)
|
||||
|
||||
|
||||
def test_policy_to_robot_action_conversion_preserves_order():
|
||||
"""Test that motor names order corresponds to tensor indices."""
|
||||
motor_names = ["gripper", "arm", "wrist"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
policy_action = torch.tensor([5.0, 10.0, 15.0])
|
||||
robot_action = processor.action(policy_action)
|
||||
|
||||
assert robot_action["gripper.pos"] == pytest.approx(5.0)
|
||||
assert robot_action["arm.pos"] == pytest.approx(10.0)
|
||||
assert robot_action["wrist.pos"] == pytest.approx(15.0)
|
||||
|
||||
|
||||
def test_policy_to_robot_action_conversion_with_numpy_input():
|
||||
"""Test conversion with numpy array input."""
|
||||
import numpy as np
|
||||
|
||||
motor_names = ["joint1", "joint2"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
policy_action = np.array([1.5, 2.5])
|
||||
robot_action = processor.action(policy_action)
|
||||
|
||||
assert robot_action["joint1.pos"] == pytest.approx(1.5)
|
||||
assert robot_action["joint2.pos"] == pytest.approx(2.5)
|
||||
|
||||
|
||||
def test_policy_to_robot_action_length_mismatch_error():
|
||||
"""Test error when policy action length doesn't match motor names."""
|
||||
motor_names = ["joint1", "joint2", "joint3"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
policy_action = torch.tensor([1.0, 2.0])
|
||||
|
||||
with pytest.raises(ValueError, match="Action must have 3 elements, got 2"):
|
||||
processor.action(policy_action)
|
||||
|
||||
policy_action = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
|
||||
with pytest.raises(ValueError, match="Action must have 3 elements, got 4"):
|
||||
processor.action(policy_action)
|
||||
|
||||
|
||||
def test_policy_to_robot_transform_features():
|
||||
"""Test feature transformation for policy to robot action processor."""
|
||||
motor_names = ["joint1", "joint2"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"action": {"type": FeatureType.ACTION, "shape": (2,)},
|
||||
"other_data": {"type": FeatureType.ENV, "shape": (1,)},
|
||||
}
|
||||
}
|
||||
|
||||
transformed = processor.transform_features(features)
|
||||
|
||||
assert "joint1.pos" in transformed[PipelineFeatureType.ACTION]
|
||||
assert "joint2.pos" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
for motor in motor_names:
|
||||
motor_feature = transformed[PipelineFeatureType.ACTION][f"{motor}.pos"]
|
||||
assert motor_feature.type == FeatureType.ACTION
|
||||
assert motor_feature.shape == (1,)
|
||||
|
||||
assert "action" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
assert "other_data" in transformed[PipelineFeatureType.ACTION]
|
||||
|
||||
|
||||
def test_policy_to_robot_get_config():
|
||||
"""Test configuration serialization."""
|
||||
motor_names = ["motor1", "motor2"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"motor_names": motor_names}
|
||||
|
||||
|
||||
def test_policy_to_robot_state_dict():
|
||||
"""Test state dict operations."""
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=["joint1"])
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
processor.load_state_dict({})
|
||||
|
||||
|
||||
def test_policy_to_robot_single_motor():
|
||||
"""Test with single motor."""
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=["single_joint"])
|
||||
|
||||
policy_action = torch.tensor([42.0])
|
||||
robot_action = processor.action(policy_action)
|
||||
|
||||
assert len(robot_action) == 1
|
||||
assert robot_action["single_joint.pos"] == pytest.approx(42.0)
|
||||
|
||||
|
||||
def test_robot_to_policy_registry():
|
||||
"""Test RobotActionToPolicyActionProcessorStep registry."""
|
||||
assert "robot_action_to_policy_action_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
retrieved_class = ProcessorStepRegistry.get("robot_action_to_policy_action_processor")
|
||||
assert retrieved_class is RobotActionToPolicyActionProcessorStep
|
||||
|
||||
instance = retrieved_class(motor_names=["test"])
|
||||
assert isinstance(instance, RobotActionToPolicyActionProcessorStep)
|
||||
assert instance.motor_names == ["test"]
|
||||
|
||||
|
||||
def test_policy_to_robot_registry():
|
||||
"""Test PolicyActionToRobotActionProcessorStep registry."""
|
||||
assert "policy_action_to_robot_action_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
retrieved_class = ProcessorStepRegistry.get("policy_action_to_robot_action_processor")
|
||||
assert retrieved_class is PolicyActionToRobotActionProcessorStep
|
||||
|
||||
instance = retrieved_class(motor_names=["test"])
|
||||
assert isinstance(instance, PolicyActionToRobotActionProcessorStep)
|
||||
assert instance.motor_names == ["test"]
|
||||
|
||||
|
||||
def test_save_and_load_robot_to_policy():
|
||||
"""Test saving and loading RobotActionToPolicyActionProcessorStep."""
|
||||
motor_names = ["joint1", "joint2", "joint3"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRobotToPolicy")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check config file exists
|
||||
config_path = Path(tmp_dir) / "testrobottopolicy.json"
|
||||
assert config_path.exists()
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir,
|
||||
"testrobottopolicy.json",
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "TestRobotToPolicy"
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Check loaded processor
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RobotActionToPolicyActionProcessorStep)
|
||||
assert loaded_processor.motor_names == motor_names
|
||||
|
||||
# Test functionality after loading
|
||||
robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0, "joint3.pos": 3.0}
|
||||
policy_action = loaded_processor.action(robot_action)
|
||||
torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
def test_save_and_load_policy_to_robot():
|
||||
"""Test saving and loading PolicyActionToRobotActionProcessorStep."""
|
||||
motor_names = ["motor_a", "motor_b"]
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestPolicyToRobot")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir,
|
||||
"testpolicytorobot.json",
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, PolicyActionToRobotActionProcessorStep)
|
||||
assert loaded_processor.motor_names == motor_names
|
||||
|
||||
policy_action = torch.tensor([10.0, 20.0])
|
||||
robot_action = loaded_processor.action(policy_action)
|
||||
assert robot_action["motor_a.pos"] == pytest.approx(10.0)
|
||||
assert robot_action["motor_b.pos"] == pytest.approx(20.0)
|
||||
|
||||
|
||||
# Integration and chaining tests
|
||||
|
||||
|
||||
def test_round_trip_conversion():
    """Test that robot->policy->robot conversion preserves values."""
    motor_names = ["joint1", "joint2", "joint3"]
    robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
    policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)

    original_robot_action = {
        "joint1.pos": 1.5,
        "joint2.pos": -2.3,
        "joint3.pos": 0.7,
    }

    policy_action = robot_to_policy.action(original_robot_action)
    final_robot_action = policy_to_robot.action(policy_action)

    for key in original_robot_action:
        original_val = original_robot_action[key]
        final_val = final_robot_action[key]
        if isinstance(final_val, torch.Tensor):
            final_val = final_val.item()
        assert final_val == pytest.approx(original_val, abs=1e-6)
|
||||
|
||||
|
||||
def test_chained_processors_in_pipeline():
|
||||
"""Test both processors chained in a pipeline."""
|
||||
motor_names = ["joint1", "joint2"]
|
||||
robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
pipeline = DataProcessorPipeline(
|
||||
[robot_to_policy, policy_to_robot],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
assert len(pipeline.steps) == 2
|
||||
assert isinstance(pipeline.steps[0], RobotActionToPolicyActionProcessorStep)
|
||||
assert isinstance(pipeline.steps[1], PolicyActionToRobotActionProcessorStep)
|
||||
|
||||
|
||||
def test_robot_to_policy_features_contract(policy_feature_factory):
|
||||
"""Test feature transformation maintains proper typing contract."""
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=["j1", "j2"])
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"j1.pos": policy_feature_factory(FeatureType.ACTION, (1,)),
|
||||
"j2.pos": policy_feature_factory(FeatureType.ACTION, (1,)),
|
||||
"other": policy_feature_factory(FeatureType.ENV, (3,)),
|
||||
}
|
||||
}
|
||||
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
assert "action" in out[PipelineFeatureType.ACTION]
|
||||
action_feature = out[PipelineFeatureType.ACTION]["action"]
|
||||
assert action_feature.type == FeatureType.ACTION
|
||||
assert action_feature.shape == (2,)
|
||||
|
||||
|
||||
def test_policy_to_robot_features_contract(policy_feature_factory):
|
||||
"""Test feature transformation maintains proper typing contract."""
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"])
|
||||
features = {
|
||||
PipelineFeatureType.ACTION: {
|
||||
"action": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"other": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
}
|
||||
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
for motor in ["m1", "m2", "m3"]:
|
||||
key = f"{motor}.pos"
|
||||
assert key in out[PipelineFeatureType.ACTION]
|
||||
motor_feature = out[PipelineFeatureType.ACTION][key]
|
||||
assert motor_feature.type == FeatureType.ACTION
|
||||
assert motor_feature.shape == (1,)
|
||||
|
||||
|
||||
def test_empty_motor_names_list():
|
||||
"""Test behavior with empty motor names list."""
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=[])
|
||||
|
||||
robot_action = {}
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
assert isinstance(policy_action, torch.Tensor)
|
||||
assert policy_action.shape == (0,)
|
||||
|
||||
|
||||
def test_empty_motor_names_list_policy_to_robot():
|
||||
"""Test PolicyActionToRobotActionProcessorStep with empty motor names."""
|
||||
processor = PolicyActionToRobotActionProcessorStep(motor_names=[])
|
||||
|
||||
policy_action = torch.tensor([])
|
||||
robot_action = processor.action(policy_action)
|
||||
|
||||
assert isinstance(robot_action, dict)
|
||||
assert len(robot_action) == 0
|
||||
|
||||
|
||||
def test_very_long_motor_names():
|
||||
"""Test with many motor names."""
|
||||
motor_names = [f"joint_{i}" for i in range(100)]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {f"joint_{i}.pos": float(i) for i in range(100)}
|
||||
policy_action = processor.action(robot_action)
|
||||
|
||||
assert policy_action.shape == (100,)
|
||||
expected = torch.tensor([float(i) for i in range(100)])
|
||||
torch.testing.assert_close(policy_action, expected)
|
||||
|
||||
|
||||
def test_special_characters_in_motor_names():
|
||||
"""Test with special characters in motor names."""
|
||||
motor_names = ["motor-1", "motor_2", "motor.3"]
|
||||
processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names)
|
||||
|
||||
robot_action = {
|
||||
"motor-1.pos": 1.0,
|
||||
"motor_2.pos": 2.0,
|
||||
"motor.3.pos": 3.0,
|
||||
}
|
||||
|
||||
policy_action = processor.action(robot_action)
|
||||
torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0]))
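The behaviour this test file pins down for the two bridge steps can be summarised in a small standalone sketch. It uses plain Python lists in place of `torch.Tensor`, and the function names `robot_to_policy` and `policy_to_robot` are hypothetical. It mirrors only what the tests assert (ordering by `motor_names`, the `"<motor>.pos"` key convention, and the length check); it is not the lerobot implementation.

```python
def robot_to_policy(motor_names: list[str], robot_action: dict[str, float]) -> list[float]:
    """Flatten a {"<motor>.pos": value} dict into a vector ordered by motor_names."""
    if len(robot_action) != len(motor_names):
        raise ValueError(
            f"Action must have {len(motor_names)} elements, got {len(robot_action)}"
        )
    # A missing "<motor>.pos" key surfaces as KeyError, as the tests expect.
    return [float(robot_action[f"{name}.pos"]) for name in motor_names]


def policy_to_robot(motor_names: list[str], policy_action: list[float]) -> dict[str, float]:
    """Expand a vector into a {"<motor>.pos": value} dict, index i -> motor_names[i]."""
    if len(policy_action) != len(motor_names):
        raise ValueError(
            f"Action must have {len(motor_names)} elements, got {len(policy_action)}"
        )
    return {f"{name}.pos": float(value) for name, value in zip(motor_names, policy_action)}


motor_names = ["gripper", "arm", "wrist"]
# Dict insertion order differs from motor_names on purpose.
original = {"arm.pos": 10.0, "gripper.pos": 5.0, "wrist.pos": 15.0}

vector = robot_to_policy(motor_names, original)
restored = policy_to_robot(motor_names, vector)
```

Because both directions index by the same `motor_names` list, the round trip is lossless regardless of the order in which the robot dict was built.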
|
||||
@@ -19,33 +19,25 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_basic_renaming():
|
||||
"""Test basic key renaming functionality."""
|
||||
rename_map = {
|
||||
"old_key1": "new_key1",
|
||||
"old_key2": "new_key2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"old_key1": torch.tensor([1.0, 2.0]),
|
||||
@@ -73,7 +65,7 @@ def test_basic_renaming():
|
||||
|
||||
def test_empty_rename_map():
|
||||
"""Test processor with empty rename map (should pass through unchanged)."""
|
||||
processor = RenameProcessor(rename_map={})
|
||||
processor = RenameObservationsProcessorStep(rename_map={})
|
||||
|
||||
observation = {
|
||||
"key1": torch.tensor([1.0]),
|
||||
@@ -92,9 +84,9 @@ def test_empty_rename_map():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
transition = create_transition()
|
||||
transition = create_transition(observation={})
|
||||
result = processor(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
@@ -107,7 +99,7 @@ def test_overlapping_rename():
|
||||
"a": "b",
|
||||
"b": "c", # This creates a potential conflict
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"a": 1,
|
||||
@@ -132,7 +124,7 @@ def test_partial_rename():
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
@@ -162,15 +154,15 @@ def test_get_config():
|
||||
"old1": "new1",
|
||||
"old2": "new2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"rename_map": rename_map}
|
||||
|
||||
|
||||
def test_state_dict():
|
||||
"""Test state dict (should be empty for RenameProcessor)."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
"""Test state dict (should be empty for RenameObservationsProcessorStep)."""
|
||||
processor = RenameObservationsProcessorStep(rename_map={"old": "new"})
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
@@ -185,9 +177,11 @@ def test_integration_with_robot_processor():
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
rename_processor = RenameProcessor(rename_map=rename_map)
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
pipeline = RobotProcessor([rename_processor])
|
||||
pipeline = DataProcessorPipeline(
|
||||
[rename_processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
observation = {
|
||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||
@@ -219,30 +213,37 @@ def test_save_and_load_pretrained():
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
||||
config_path = (
|
||||
Path(tmp_dir) / "testrenameprocessorstep.json"
|
||||
) # Based on name="TestRenameProcessorStep"
|
||||
assert config_path.exists()
|
||||
|
||||
# No state files should be created for RenameProcessor
|
||||
# No state files should be created for RenameObservationsProcessorStep
|
||||
state_files = list(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 0
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir,
|
||||
config_filename="testrenameprocessorstep.json",
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "TestRenameProcessor"
|
||||
assert loaded_pipeline.name == "TestRenameProcessorStep"
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Check that loaded processor works correctly
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == rename_map
|
||||
|
||||
# Test functionality after loading
|
||||
@@ -259,24 +260,26 @@ def test_save_and_load_pretrained():
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that RenameProcessor is properly registered."""
|
||||
"""Test that RenameObservationsProcessorStep is properly registered."""
|
||||
# Check that it's registered
|
||||
assert "rename_processor" in ProcessorStepRegistry.list()
|
||||
assert "rename_observations_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
# Get from registry
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
||||
assert retrieved_class is RenameProcessor
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_observations_processor")
|
||||
assert retrieved_class is RenameObservationsProcessorStep
|
||||
|
||||
# Create instance from registry
|
||||
instance = retrieved_class(rename_map={"old": "new"})
|
||||
assert isinstance(instance, RenameProcessor)
|
||||
assert isinstance(instance, RenameObservationsProcessorStep)
|
||||
assert instance.rename_map == {"old": "new"}
|
||||
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = RobotProcessor([processor])
|
||||
processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save and load
|
||||
@@ -285,24 +288,26 @@ def test_registry_based_save_load():
|
||||
# Verify config uses registry name
|
||||
import json
|
||||
|
||||
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
||||
with open(Path(tmp_dir) / "dataprocessorpipeline.json") as f:  # Default name is "DataProcessorPipeline"
|
||||
config = json.load(f)
|
||||
|
||||
assert "registry_name" in config["steps"][0]
|
||||
assert config["steps"][0]["registry_name"] == "rename_processor"
|
||||
assert config["steps"][0]["registry_name"] == "rename_observations_processor"
|
||||
assert "class" not in config["steps"][0] # Should use registry, not module path
|
||||
|
||||
# Load should work
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir, config_filename="dataprocessorpipeline.json"
|
||||
)
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert isinstance(loaded_processor, RenameObservationsProcessorStep)
|
||||
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
||||
|
||||
|
||||
def test_chained_rename_processors():
|
||||
"""Test multiple RenameProcessors in a pipeline."""
|
||||
"""Test multiple RenameProcessorSteps in a pipeline."""
|
||||
# First processor: rename raw keys to intermediate format
|
||||
processor1 = RenameProcessor(
|
||||
processor1 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"pos": "agent_position",
|
||||
"img": "camera_image",
|
||||
@@ -310,14 +315,16 @@ def test_chained_rename_processors():
|
||||
)
|
||||
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameProcessor(
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor1, processor2], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
observation = {
|
||||
"pos": np.array([1.0, 2.0]),
|
||||
@@ -353,7 +360,7 @@ def test_nested_observation_rename():
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
@@ -383,7 +390,7 @@ def test_nested_observation_rename():
|
||||
def test_value_types_preserved():
|
||||
"""Test that various value types are preserved during renaming."""
|
||||
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
tensor_value = torch.randn(3, 3)
|
||||
array_value = np.random.rand(2, 2)
|
||||
@@ -410,58 +417,87 @@ def test_value_types_preserved():
|
||||
assert processed_obs["old_list"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_feature_contract_basic_renaming(policy_feature_factory):
|
||||
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
|
||||
def test_features_basic_renaming(policy_feature_factory):
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
|
||||
out = processor.feature_contract(features.copy())
|
||||
out = processor.transform_features(features.copy())
|
||||
|
||||
# Values preserved and typed
|
||||
assert out["x"] == features["a"]
|
||||
assert out["y"] == features["b"]
|
||||
assert out["c"] == features["c"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"]
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
# Input not mutated
|
||||
assert set(features) == {"a", "b", "c"}
|
||||
assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"}
|
||||
|
||||
|
||||
def test_feature_contract_overlapping_keys(policy_feature_factory):
|
||||
def test_features_overlapping_keys(policy_feature_factory):
|
||||
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
||||
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
|
||||
processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"a": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.VISUAL, (2,)),
|
||||
},
|
||||
}
|
||||
out = processor.feature_contract(features)
|
||||
out = processor.transform_features(features)
|
||||
|
||||
assert set(out) == {"b", "c"}
|
||||
assert out["b"] == features["a"] # 'a' renamed to 'b'
|
||||
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"]
|
||||
) # 'a' renamed to 'b'
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"]
|
||||
) # 'b' renamed to 'c'
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_feature_contract_chained_processors(policy_feature_factory):
|
||||
def test_features_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameProcessor(
|
||||
processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameObservationsProcessorStep(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
)
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
pipeline = DataProcessorPipeline([processor1, processor2])
|
||||
|
||||
spec = {
|
||||
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"pos": policy_feature_factory(FeatureType.VISUAL, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.VISUAL, (1,)),
|
||||
},
|
||||
}
|
||||
out = pipeline.feature_contract(initial_features=spec)
|
||||
out = pipeline.transform_features(initial_features=spec)
|
||||
|
||||
assert set(out) == {"observation.state", "observation.image", "extra"}
|
||||
assert out["observation.state"] == spec["pos"]
|
||||
assert out["observation.image"] == spec["img"]
|
||||
assert out["extra"] == spec["extra"]
|
||||
assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"}
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.state"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["pos"]
|
||||
)
|
||||
assert (
|
||||
out[PipelineFeatureType.OBSERVATION]["observation.image"]
|
||||
== spec[PipelineFeatureType.OBSERVATION]["img"]
|
||||
)
|
||||
assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_rename_stats_basic():
    orig = {
        "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
        "action": {"mean": np.array([0.0])},
    }
    mapping = {"observation.state": "observation.robot_state"}
    renamed = rename_stats(orig, mapping)
    assert "observation.robot_state" in renamed and "observation.state" not in renamed
    # Ensure deep copy: mutate original and verify renamed unaffected
    orig["observation.state"]["mean"][0] = 42.0
    assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
414
tests/processor/test_sac_processor.py
Normal file
@@ -0,0 +1,414 @@
|
||||
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SAC policy processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DataProcessorPipeline,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch
|
||||
|
||||
|
||||
def create_default_config():
    """Create a default SAC configuration for testing."""
    config = SACConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    config.device = "cpu"
    return config
|
||||
|
||||
|
||||
def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
        ACTION: {"min": torch.full((5,), -1.0), "max": torch.ones(5)},
    }
|
||||
|
||||
|
||||
def test_make_sac_processor_basic():
    """Test basic creation of SAC processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_sac_pre_post_processors(
        config,
        stats,
    )

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 4
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
|
||||
|
||||
|
||||
def test_sac_processor_normalization_modes():
|
||||
"""Test that SAC processor correctly handles different normalization modes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
# State should be mean-std normalized
|
||||
# Action should be min-max normalized to [-1, 1]
|
||||
assert processed[OBS_STATE].shape == (1, 10)
|
||||
assert processed[TransitionKey.ACTION.value].shape == (1, 5)
|
||||
|
||||
# Process action through postprocessor
|
||||
postprocessed = postprocessor(processed[TransitionKey.ACTION.value])
|
||||
|
||||
# Check that action is unnormalized (but still batched)
|
||||
assert postprocessed.shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_cuda():
|
||||
"""Test SAC processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION.value].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
postprocessed = postprocessor(processed[TransitionKey.ACTION.value])
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed.device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_accelerate_scenario():
|
||||
"""Test SAC processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data stays on same GPU
|
||||
assert processed[OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_sac_processor_multi_gpu():
|
||||
"""Test SAC processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(10).to(device)}
|
||||
action = torch.randn(5).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data stays on cuda:1
|
||||
assert processed[OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
def test_sac_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
processed = preprocessor(batch)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_sac_processor_save_and_load():
|
||||
"""Test saving and loading SAC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, config_filename="policy_preprocessor.json"
|
||||
)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
processed = loaded_preprocessor(batch)
|
||||
assert processed[OBS_STATE].shape == (1, 10)
|
||||
assert processed[TransitionKey.ACTION.value].shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_mixed_precision():
|
||||
"""Test SAC processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Update normalizer to use the same device as the device processor
|
||||
norm_step = step # Now type checker knows this is NormalizerProcessorStep
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=norm_step.features,
|
||||
norm_map=norm_step.norm_map,
|
||||
stats=norm_step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float16, # Match the float16 dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
|
||||
|
||||
|
||||
def test_sac_processor_batch_data():
|
||||
"""Test SAC processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 32
|
||||
observation = {OBS_STATE: torch.randn(batch_size, 10)}
|
||||
action = torch.randn(batch_size, 5)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed[OBS_STATE].shape == (batch_size, 10)
|
||||
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
|
||||
|
||||
|
||||
def test_sac_processor_edge_cases():
|
||||
"""Test SAC processor with edge cases."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Test with observation that has no state key but still exists
|
||||
observation = {"observation.dummy": torch.randn(1)} # Some dummy observation to pass validation
|
||||
action = torch.randn(5)
|
||||
batch = {TransitionKey.ACTION.value: action, **observation}
|
||||
processed = preprocessor(batch)
|
||||
# observation.state wasn't in original, so it won't be in processed
|
||||
assert OBS_STATE not in processed
|
||||
assert processed[TransitionKey.ACTION.value].shape == (1, 5)
|
||||
|
||||
# Test with zero action (representing "null" action)
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5))
|
||||
batch = transition_to_batch(transition)
|
||||
processed = preprocessor(batch)
|
||||
assert processed[OBS_STATE].shape == (1, 10)
|
||||
# Action should be present and batched, even if it's zeros
|
||||
assert processed[TransitionKey.ACTION.value].shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
# Device processor converts to bfloat16
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
|
||||
elif isinstance(step, NormalizerProcessorStep):
|
||||
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
|
||||
norm_step = step # Now type checker knows this is NormalizerProcessorStep
|
||||
modified_steps.append(
|
||||
NormalizerProcessorStep(
|
||||
features=norm_step.features,
|
||||
norm_map=norm_step.norm_map,
|
||||
stats=norm_step.stats,
|
||||
device=config.device,
|
||||
dtype=torch.float32, # Deliberately configured as float32
|
||||
)
|
||||
)
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Verify initial normalizer configuration
|
||||
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
|
||||
assert normalizer_step.dtype == torch.float32
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(5, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(batch)
|
||||
|
||||
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
|
||||
assert processed[OBS_STATE].dtype == torch.bfloat16
|
||||
assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16
|
||||
|
||||
# Verify normalizer automatically adapted its internal state
|
||||
assert normalizer_step.dtype == torch.bfloat16
|
||||
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
|
||||
assert stat_tensor.dtype == torch.bfloat16
|
||||
459
tests/processor/test_smolvla_processor.py
Normal file
@@ -0,0 +1,459 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for SmolVLA policy processor."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.processor_smolvla import (
|
||||
SmolVLANewLineProcessor,
|
||||
make_smolvla_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
NormalizerProcessorStep,
|
||||
ProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, transition_to_batch
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
    """Mock tokenizer processor step for testing."""

    def __init__(self, *args, **kwargs):
        # Accept any arguments to mimic the real TokenizerProcessorStep interface
        pass

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        # Pass through transition unchanged
        return transition

    def transform_features(self, features):
        # Pass through features unchanged
        return features
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SmolVLA configuration for testing."""
|
||||
config = SmolVLAConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.VISUAL: NormalizationMode.IDENTITY,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
config.device = "cpu"
|
||||
config.vlm_model_name = "HuggingFaceTB/SmolVLM-Instruct"
|
||||
config.pad_language_to = "max_length"
|
||||
config.tokenizer_max_length = 100
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)},
|
||||
OBS_IMAGE: {}, # No normalization for images
|
||||
ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_smolvla_processor_basic():
|
||||
"""Test basic creation of SmolVLA processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
with patch(
|
||||
"lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
|
||||
):
|
||||
preprocessor, postprocessor = make_smolvla_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "policy_preprocessor"
|
||||
assert postprocessor.name == "policy_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 6
|
||||
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
|
||||
# Step 3 would be TokenizerProcessorStep but it's mocked
|
||||
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
|
||||
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
|
||||
|
||||
|
||||
def test_smolvla_newline_processor_single_task():
    """Test SmolVLANewLineProcessor with single task string."""
    processor = SmolVLANewLineProcessor()

    # Test with task that doesn't have newline
    transition = create_transition(complementary_data={"task": "test task"})
    result = processor(transition)
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"

    # Test with task that already has newline
    transition = create_transition(complementary_data={"task": "test task\n"})
    result = processor(transition)
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"


def test_smolvla_newline_processor_list_of_tasks():
    """Test SmolVLANewLineProcessor with list of task strings."""
    processor = SmolVLANewLineProcessor()

    # Test with list of tasks
    tasks = ["task1", "task2\n", "task3"]
    transition = create_transition(complementary_data={"task": tasks})
    result = processor(transition)
    expected = ["task1\n", "task2\n", "task3\n"]
    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected


def test_smolvla_newline_processor_empty_transition():
    """Test SmolVLANewLineProcessor with empty transition."""
    processor = SmolVLANewLineProcessor()

    # Test with no complementary_data
    transition = create_transition()
    result = processor(transition)
    assert result == transition

    # Test with complementary_data but no task
    transition = create_transition(complementary_data={"other": "data"})
    result = processor(transition)
    assert result == transition

    # Test with None task
    transition = create_transition(complementary_data={"task": None})
    result = processor(transition)
    assert result == transition


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_smolvla_processor_cuda():
    """Test SmolVLA processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}

        def transform_features(self, features):
            return features

    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, postprocessor = make_smolvla_pre_post_processors(
            config,
            stats,
        )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(8),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(7)
    transition = create_transition(observation, action, complementary_data={"task": "test task"})

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_smolvla_processor_accelerate_scenario():
    """Test SmolVLA processor in simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}

        def transform_features(self, features):
            return features

    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, postprocessor = make_smolvla_pre_post_processors(
            config,
            stats,
        )

    # Simulate Accelerate: data already on GPU and batched
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(1, 8).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 7).to(device)
    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_smolvla_processor_multi_gpu():
    """Test SmolVLA processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    # Mock the tokenizer processor to act as pass-through
    class MockTokenizerProcessorStep(ProcessorStep):
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, transition):
            return transition

        def state_dict(self):
            return {}

        def load_state_dict(self, state):
            pass

        def reset(self):
            pass

        def get_config(self):
            return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"}

        def transform_features(self, features):
            return features

    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, postprocessor = make_smolvla_pre_post_processors(
            config,
            stats,
        )

    # Simulate data on different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(1, 8).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 7).to(device)
    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_smolvla_processor_without_stats():
    """Test SmolVLA processor creation without dataset statistics."""
    config = create_default_config()

    # Mock the tokenizer processor
    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, postprocessor = make_smolvla_pre_post_processors(
            config,
            dataset_stats=None,
        )

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None


def test_smolvla_newline_processor_state_dict():
    """Test SmolVLANewLineProcessor state dict methods."""
    processor = SmolVLANewLineProcessor()

    # Test state_dict (should be empty)
    state = processor.state_dict()
    assert state == {}

    # Test load_state_dict (should do nothing)
    processor.load_state_dict({})

    # Test reset (should do nothing)
    processor.reset()

    # Test get_config
    config = processor.get_config()
    assert config == {}


def test_smolvla_newline_processor_transform_features():
    """Test SmolVLANewLineProcessor transform_features method."""
    processor = SmolVLANewLineProcessor()

    # Test transform_features
    features = {
        PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
    }
    result = processor.transform_features(features)
    assert result == features  # Should return unchanged


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_smolvla_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, _ = make_smolvla_pre_post_processors(
            config,
            stats,
        )

    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            modified_steps.append(
                NormalizerProcessorStep(
                    features=step.features,
                    norm_map=step.norm_map,
                    stats=step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5)
    normalizer_step = preprocessor.steps[5]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data with both state and visual observations
    observation = {
        OBS_STATE: torch.randn(8, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(7, dtype=torch.float32)
    transition = create_transition(
        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
    )

    batch = transition_to_batch(transition)

    # Process through full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    # Check state stats (has normalization)
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
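Reviewer note: the `SmolVLANewLineProcessor` tests in this file pin down one behavior: a task string gets a trailing newline if it lacks one, a list of tasks is handled element by element, and a `None` task is left untouched. The sketch below restates that contract in plain Python. The helper name `ensure_trailing_newline` is hypothetical and is not part of lerobot; it only mirrors what the assertions above expect, not how the real step is implemented.

```python
from typing import List, Union

Task = Union[str, List[str], None]


def ensure_trailing_newline(task: Task) -> Task:
    """Hypothetical helper: make every task string end with a newline."""
    if task is None:
        # Nothing to normalize; mirrors the "None task" test case.
        return None
    if isinstance(task, str):
        return task if task.endswith("\n") else task + "\n"
    # Assume a list of strings and treat each element independently.
    return [t if t.endswith("\n") else t + "\n" for t in task]


single = ensure_trailing_newline("test task")
many = ensure_trailing_newline(["task1", "task2\n", "task3"])
```

The operation is idempotent: applying the helper twice gives the same result as applying it once, which is what the "already has newline" test case checks.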
467
tests/processor/test_tdmpc_processor.py
Normal file
@@ -0,0 +1,467 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TDMPC policy processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DataProcessorPipeline,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch


def create_default_config():
    """Create a default TDMPC configuration for testing."""
    config = TDMPCConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(12,)),
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    config.device = "cpu"
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(12), "std": torch.ones(12)},
        OBS_IMAGE: {},  # No normalization for images
        ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
    }


def test_make_tdmpc_processor_basic():
    """Test basic creation of TDMPC processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 4
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)


def test_tdmpc_processor_normalization():
    """Test that TDMPC processor correctly normalizes and unnormalizes data."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Create test data
    observation = {
        OBS_STATE: torch.randn(12),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is processed and batched
    assert processed[OBS_STATE].shape == (1, 12)
    assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (1, 6)

    # Process action through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is unnormalized (but still batched)
    assert postprocessed.shape == (1, 6)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_cuda():
    """Test TDMPC processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(12),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

    # Process through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is back on CPU
    assert postprocessed.device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_accelerate_scenario():
    """Test TDMPC processor in simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Simulate Accelerate: data already on GPU
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(12).to(device),
        OBS_IMAGE: torch.randn(3, 224, 224).to(device),
    }
    action = torch.randn(6).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_tdmpc_processor_multi_gpu():
    """Test TDMPC processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Simulate data on different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(12).to(device),
        OBS_IMAGE: torch.randn(3, 224, 224).to(device),
    }
    action = torch.randn(6).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_tdmpc_processor_without_stats():
    """Test TDMPC processor creation without dataset statistics."""
    config = create_default_config()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None

    # Process should still work
    observation = {
        OBS_STATE: torch.randn(12),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(6)
    transition = create_transition(observation, action)
    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed is not None


def test_tdmpc_processor_save_and_load():
    """Test saving and loading TDMPC processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save preprocessor
        preprocessor.save_pretrained(tmpdir)

        # Load preprocessor
        loaded_preprocessor = DataProcessorPipeline.from_pretrained(
            tmpdir, config_filename="policy_preprocessor.json"
        )

        # Test that loaded processor works
        observation = {
            OBS_STATE: torch.randn(12),
            OBS_IMAGE: torch.randn(3, 224, 224),
        }
        action = torch.randn(6)
        transition = create_transition(observation, action)

        batch = transition_to_batch(transition)
        processed = loaded_preprocessor(batch)
        assert processed[OBS_STATE].shape == (1, 12)
        assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
        assert processed[TransitionKey.ACTION.value].shape == (1, 6)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_mixed_precision():
    """Test TDMPC processor with mixed precision."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Create processor
    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Replace DeviceProcessorStep with one that uses float16
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Update normalizer to use the same device as the device processor
            modified_steps.append(
                NormalizerProcessorStep(
                    features=step.features,
                    norm_map=step.norm_map,
                    stats=step.stats,
                    device=config.device,
                    dtype=torch.float16,  # Match the float16 dtype
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Create test data
    observation = {
        OBS_STATE: torch.randn(12, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that data is converted to float16
    assert processed[OBS_STATE].dtype == torch.float16
    assert processed[OBS_IMAGE].dtype == torch.float16
    assert processed[TransitionKey.ACTION.value].dtype == torch.float16


def test_tdmpc_processor_batch_data():
    """Test TDMPC processor with batched data."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Test with batched data
    batch_size = 64
    observation = {
        OBS_STATE: torch.randn(batch_size, 12),
        OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
    }
    action = torch.randn(batch_size, 6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor
    processed = preprocessor(batch)

    # Check that batch dimension is preserved
    assert processed[OBS_STATE].shape == (batch_size, 12)
    assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (batch_size, 6)


def test_tdmpc_processor_edge_cases():
    """Test TDMPC processor with edge cases."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Test with only state observation (no image)
    observation = {OBS_STATE: torch.randn(12)}
    action = torch.randn(6)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed[OBS_STATE].shape == (1, 12)
    assert OBS_IMAGE not in processed

    # Test with only image observation (no state)
    observation = {OBS_IMAGE: torch.randn(3, 224, 224)}
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
    assert OBS_STATE not in processed


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_tdmpc_pre_post_processors(
        config,
        stats,
    )

    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            modified_steps.append(
                NormalizerProcessorStep(
                    features=step.features,
                    norm_map=step.norm_map,
                    stats=step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify initial normalizer configuration
    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data with both state and visual observations
    observation = {
        OBS_STATE: torch.randn(12, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    # Check state stats (has normalization)
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
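Reviewer note: `create_default_stats` in this file uses mean 0 and std 1 for the state, and min -1 and max 1 for the action. Under the usual definitions of `MEAN_STD` and `MIN_MAX` normalization, those values make normalization close to an identity map, which may be why `test_tdmpc_processor_normalization` only asserts shapes. The sketch below illustrates this in plain Python. The formulas, the `EPS` constant, and the function names are assumptions for illustration and are not taken from the `NormalizerProcessorStep` source.

```python
EPS = 1e-8  # assumed small constant guarding against division by zero


def mean_std_normalize(x, mean, std):
    """Standardize element-wise: (x - mean) / (std + EPS)."""
    return [(xi - m) / (s + EPS) for xi, m, s in zip(x, mean, std)]


def min_max_normalize(x, lo, hi):
    """Map the range [lo, hi] linearly onto [-1, 1]."""
    return [2.0 * (xi - l) / (h - l + EPS) - 1.0 for xi, l, h in zip(x, lo, hi)]


def min_max_unnormalize(y, lo, hi):
    """Invert min_max_normalize."""
    return [(yi + 1.0) / 2.0 * (h - l + EPS) + l for yi, l, h in zip(y, lo, hi)]


# Same shapes as the test fixtures: 6-dimensional action, min=-1, max=1.
action = [0.5, -0.25, 1.0, -1.0, 0.0, 0.75]
lo, hi = [-1.0] * 6, [1.0] * 6
normalized = min_max_normalize(action, lo, hi)
restored = min_max_unnormalize(normalized, lo, hi)

# State with mean=0 and std=1 is likewise left almost unchanged.
state = [0.3, -1.2, 2.0]
standardized = mean_std_normalize(state, [0.0] * 3, [1.0] * 3)
```

A value-level assertion would need stats that differ from these defaults, for example a nonzero mean, so that the normalized output is visibly different from the input.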
1029
tests/processor/test_tokenizer_processor.py
Normal file
File diff suppressed because it is too large
462
tests/processor/test_vqbet_processor.py
Normal file
@@ -0,0 +1,462 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for VQBeT policy processor."""

import tempfile

import pytest
import torch

from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DataProcessorPipeline,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    RenameObservationsProcessorStep,
    TransitionKey,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch


def create_default_config():
    """Create a default VQBeT configuration for testing."""
    config = VQBeTConfig()
    config.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }
    config.normalization_mapping = {
        FeatureType.STATE: NormalizationMode.MEAN_STD,
        FeatureType.VISUAL: NormalizationMode.IDENTITY,
        FeatureType.ACTION: NormalizationMode.MIN_MAX,
    }
    config.device = "cpu"
    return config


def create_default_stats():
    """Create default dataset statistics for testing."""
    return {
        OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)},
        OBS_IMAGE: {},  # No normalization for images
        ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)},
    }


def test_make_vqbet_processor_basic():
    """Test basic creation of VQBeT processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Check processor names
    assert preprocessor.name == "policy_preprocessor"
    assert postprocessor.name == "policy_postprocessor"

    # Check steps in preprocessor
    assert len(preprocessor.steps) == 4
    assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)

    # Check steps in postprocessor
    assert len(postprocessor.steps) == 2
    assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
    assert isinstance(postprocessor.steps[1], DeviceProcessorStep)


def test_vqbet_processor_with_images():
    """Test VQBeT processor with image and state observations."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Create test data with images and states
    observation = {
        OBS_STATE: torch.randn(8),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(7)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is batched
    assert processed[OBS_STATE].shape == (1, 8)
    assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (1, 7)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_cuda():
    """Test VQBeT processor with CUDA device."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Create CPU data
    observation = {
        OBS_STATE: torch.randn(8),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(7)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is on CUDA
    assert processed[OBS_STATE].device.type == "cuda"
    assert processed[OBS_IMAGE].device.type == "cuda"
    assert processed[TransitionKey.ACTION.value].device.type == "cuda"

    # Process through postprocessor
    postprocessed = postprocessor(processed[TransitionKey.ACTION.value])

    # Check that action is back on CPU
    assert postprocessed.device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_accelerate_scenario():
    """Test VQBeT processor in simulated Accelerate scenario."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Simulate Accelerate: data already on GPU and batched
    device = torch.device("cuda:0")
    observation = {
        OBS_STATE: torch.randn(1, 8).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 7).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on same GPU
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_vqbet_processor_multi_gpu():
    """Test VQBeT processor with multi-GPU setup."""
    config = create_default_config()
    config.device = "cuda:0"
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Simulate data on different GPU
    device = torch.device("cuda:1")
    observation = {
        OBS_STATE: torch.randn(1, 8).to(device),
        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
    }
    action = torch.randn(1, 7).to(device)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data stays on cuda:1
    assert processed[OBS_STATE].device == device
    assert processed[OBS_IMAGE].device == device
    assert processed[TransitionKey.ACTION.value].device == device


def test_vqbet_processor_without_stats():
    """Test VQBeT processor creation without dataset statistics."""
    config = create_default_config()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)

    # Should still create processors
    assert preprocessor is not None
    assert postprocessor is not None

    # Process should still work
    observation = {
        OBS_STATE: torch.randn(8),
        OBS_IMAGE: torch.randn(3, 224, 224),
    }
    action = torch.randn(7)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    processed = preprocessor(batch)
    assert processed is not None


def test_vqbet_processor_save_and_load():
    """Test saving and loading VQBeT processor."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save preprocessor
        preprocessor.save_pretrained(tmpdir)

        # Load preprocessor
        loaded_preprocessor = DataProcessorPipeline.from_pretrained(
            tmpdir, config_filename="policy_preprocessor.json"
        )

        # Test that loaded processor works
        observation = {
            OBS_STATE: torch.randn(8),
            OBS_IMAGE: torch.randn(3, 224, 224),
        }
        action = torch.randn(7)
        transition = create_transition(observation, action)

        batch = transition_to_batch(transition)
        processed = loaded_preprocessor(batch)
        assert processed[OBS_STATE].shape == (1, 8)
        assert processed[OBS_IMAGE].shape == (1, 3, 224, 224)
        assert processed[TransitionKey.ACTION.value].shape == (1, 7)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_mixed_precision():
    """Test VQBeT processor with mixed precision."""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Create processor
    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Replace DeviceProcessorStep with one that uses float16
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Update normalizer to use the same device as the device processor
            modified_steps.append(
                NormalizerProcessorStep(
                    features=step.features,
                    norm_map=step.norm_map,
                    stats=step.stats,
                    device=config.device,
                    dtype=torch.float16,  # Match the float16 dtype
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Create test data
    observation = {
        OBS_STATE: torch.randn(8, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(7, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that data is converted to float16
    assert processed[OBS_STATE].dtype == torch.float16
    assert processed[OBS_IMAGE].dtype == torch.float16
    assert processed[TransitionKey.ACTION.value].dtype == torch.float16


def test_vqbet_processor_large_batch():
    """Test VQBeT processor with large batch sizes."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Test with large batch
    batch_size = 128
    observation = {
        OBS_STATE: torch.randn(batch_size, 8),
        OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
    }
    action = torch.randn(batch_size, 7)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through preprocessor

    processed = preprocessor(batch)

    # Check that batch dimension is preserved
    assert processed[OBS_STATE].shape == (batch_size, 8)
    assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224)
    assert processed[TransitionKey.ACTION.value].shape == (batch_size, 7)


def test_vqbet_processor_sequential_processing():
    """Test VQBeT processor with sequential data processing."""
    config = create_default_config()
    stats = create_default_stats()

    preprocessor, postprocessor = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Process multiple samples sequentially
    results = []
    for _ in range(5):
        observation = {
            OBS_STATE: torch.randn(8),
            OBS_IMAGE: torch.randn(3, 224, 224),
        }
        action = torch.randn(7)
        transition = create_transition(observation, action)

        batch = transition_to_batch(transition)

        processed = preprocessor(batch)
        results.append(processed)

    # Check that all results are consistent
    for result in results:
        assert result[OBS_STATE].shape == (1, 8)
        assert result[OBS_IMAGE].shape == (1, 3, 224, 224)
        assert result[TransitionKey.ACTION.value].shape == (1, 7)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_bfloat16_device_float32_normalizer():
    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_vqbet_pre_post_processors(
        config,
        stats,
    )

    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
    modified_steps = []
    for step in preprocessor.steps:
        if isinstance(step, DeviceProcessorStep):
            # Device processor converts to bfloat16
            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
        elif isinstance(step, NormalizerProcessorStep):
            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
            modified_steps.append(
                NormalizerProcessorStep(
                    features=step.features,
                    norm_map=step.norm_map,
                    stats=step.stats,
                    device=config.device,
                    dtype=torch.float32,  # Deliberately configured as float32
                )
            )
        else:
            modified_steps.append(step)
    preprocessor.steps = modified_steps

    # Verify initial normalizer configuration
    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
    assert normalizer_step.dtype == torch.float32

    # Create test data with both state and visual observations
    observation = {
        OBS_STATE: torch.randn(8, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(7, dtype=torch.float32)
    transition = create_transition(observation, action)

    batch = transition_to_batch(transition)

    # Process through full pipeline
    processed = preprocessor(batch)

    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
    assert processed[OBS_STATE].dtype == torch.bfloat16
    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16

    # Verify normalizer automatically adapted its internal state
    assert normalizer_step.dtype == torch.bfloat16
    # Check state stats (has normalization)
    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
        assert stat_tensor.dtype == torch.bfloat16
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check