Compare commits

..

166 Commits

Author SHA1 Message Date
AdilZouitine
dcd850feab Refactor SACObservationEncoder to improve modularity and readability. Split initialization into dedicated methods for image and state layers, and enhance caching logic for image features. Update forward method to streamline feature encoding and ensure proper normalization handling. 2025-04-18 15:10:22 +02:00
AdilZouitine
1ce368503d Refactor SACPolicy initialization by breaking down the constructor into smaller methods for normalization, encoders, critics, actor, and temperature setup. This enhances readability and maintainability. 2025-04-18 15:10:22 +02:00
AdilZouitine
fb075a709d Refactor input and output normalization handling in SACPolicy for improved clarity and efficiency. Consolidate encoder initialization logic and remove redundant else statements. 2025-04-18 15:10:22 +02:00
AdilZouitine
3424644ecd Fix init temp
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
c37936f2c9 Update log_std_min type to float in PolicyConfig for consistency 2025-04-18 15:10:22 +02:00
AdilZouitine
c5382a450c fix caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
2f7339b410 Handle caching
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
9e5f254db0 change the tanh distribution to match hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
8122721f6d match target entropy hil serl
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
5c352ae558 stick to hil serl nn architecture
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
9386892f8e Refactor modeling_sac and parameter handling for clarity and reusability.
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
2025-04-18 15:10:22 +02:00
AdilZouitine
267a837a2c fix encoder training 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
28b595c651 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
Michel Aractingi
9fd4c21d4d General fixes in code, removed delta action, fixed grasp penalty, added logic to put gripper reward in info 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
02e1ed0bfb [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
AdilZouitine
e18274bc9a fix caching and dataset stats is optional 2025-04-18 15:10:22 +02:00
AdilZouitine
68c271ad25 Add rounding for safety 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
a3ada81816 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
AdilZouitine
203315d378 fix sign issue 2025-04-18 15:10:22 +02:00
AdilZouitine
78c640b6d8 Refactor complementary_info handling in ReplayBuffer 2025-04-18 15:10:22 +02:00
AdilZouitine
d5a87f67cf Handle gripper penalty 2025-04-18 15:10:22 +02:00
AdilZouitine
8bcf41761d fix caching 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
1efaf02df9 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
AdilZouitine
cf58890bb0 fix indentation issue 2025-04-18 15:10:22 +02:00
AdilZouitine
7c2c67fc3c Enhance SAC configuration and replay buffer with asynchronous prefetching support
- Added async_prefetch parameter to SACConfig for improved buffer management.
- Implemented get_iterator method in ReplayBuffer to support asynchronous prefetching of batches.
- Updated learner_server to utilize the new iterator for online and offline sampling, enhancing training efficiency.
2025-04-18 15:10:22 +02:00
AdilZouitine
70130b9841 Enhance SACPolicy to support shared encoder and optimize action selection
- Cached encoder output in select_action method to reduce redundant computations.
- Updated action selection and grasp critic calls to utilize cached encoder features when available.
2025-04-18 15:10:22 +02:00
AdilZouitine
6167886472 Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted learner server to utilize optimized parameters for grasp critic during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
2025-04-18 15:10:22 +02:00
AdilZouitine
f9fb9d4594 Refactor SACPolicy for improved readability and action dimension handling
- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated continuous action dimension calculation to enhance clarity and maintainability.
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included conditionally based on configuration settings.
2025-04-18 15:10:22 +02:00
AdilZouitine
d86d29fe21 Add mock gripper support and enhance SAC policy action handling
- Introduced mock_gripper parameter in ManiskillEnvConfig to enable gripper simulation.
- Added ManiskillMockGripperWrapper to adjust action space for environments with discrete actions.
- Updated SACPolicy to compute continuous action dimensions correctly, ensuring compatibility with the new gripper setup.
- Refactored action handling in the training loop to accommodate the changes in action dimensions.
2025-04-18 15:10:22 +02:00
AdilZouitine
f83d215e7a Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions.
- Simplified forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for grasp critic based on configuration settings.
2025-04-18 15:10:22 +02:00
AdilZouitine
7361a11a4d Refactor SAC configuration and policy to support discrete actions
- Removed GraspCriticNetworkConfig class and integrated its parameters into SACConfig.
- Added num_discrete_actions parameter to SACConfig for better action handling.
- Updated SACPolicy to conditionally create grasp critic networks based on num_discrete_actions.
- Enhanced grasp critic forward pass to handle discrete actions and compute losses accordingly.
2025-04-18 15:10:22 +02:00
Michel Aractingi
0cce2fe0fa Added Gripper quantization wrapper and grasp penalty
removed complementary info from buffer and learner server
removed get_gripper_action function
added gripper parameters to `common/envs/configs.py`
2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
88d26ae976 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
s1lent4gnt
3a2308d86f Add grasp critic to the training loop
- Integrated the grasp critic gradient update to the training loop in learner_server
- Added Adam optimizer and configured grasp critic learning rate in configuration_sac
- Added target critics networks update after the critics gradient step
2025-04-18 15:10:22 +02:00
s1lent4gnt
fdd04efdb7 Add get_gripper_action method to GamepadController 2025-04-18 15:10:22 +02:00
s1lent4gnt
ff18be18ad Add gripper penalty wrapper 2025-04-18 15:10:22 +02:00
s1lent4gnt
427720426b Add complementary info in the replay buffer
- Added complementary info in the add method
- Added complementary info in the sample method
2025-04-18 15:10:22 +02:00
s1lent4gnt
66693965c0 Add grasp critic
- Implemented grasp critic to evaluate gripper actions
- Added corresponding config parameters for tuning
2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
334cf8143e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
AdilZouitine
5b49601072 Fix convergence of sac, multiple torch compile on the same model caused divergence 2025-04-18 15:10:22 +02:00
AdilZouitine
0185a0b6fd Fix cuda graph break 2025-04-18 15:10:22 +02:00
s1lent4gnt
70d418935d Fix: Prevent Invalid next_state References When optimize_memory=True (#918) 2025-04-18 15:10:22 +02:00
pre-commit-ci[bot]
eb44a06a9b [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:10:22 +02:00
Michel Aractingi
8eb3c1510c Added support for controlling the gripper with the pygame interface of gamepad
Minor modifications in gym_manipulator to quantize the gripper actions
clamped the observations after F.resize in ConvertToLeRobotObservation wrapper due to a bug in F.resize, images were returned exceeding the maximum value of 1.0
2025-04-18 15:10:22 +02:00
AdilZouitine
4d5ecb082e Refactor SACPolicy for improved type annotations and readability
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
2025-04-18 15:10:22 +02:00
AdilZouitine
6e687e2910 Refactor SACPolicy and learner_server for improved clarity and functionality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
2025-04-18 15:10:22 +02:00
AdilZouitine
eb710647bf Refactor actor_server.py for improved structure and logging
- Consolidated logging initialization and enhanced logging for actor processes.
- Streamlined the handling of gRPC connections and process management.
- Improved readability by organizing core algorithm functions and communication functions.
- Added detailed comments and documentation for clarity.
- Ensured proper queue management and shutdown handling for actor processes.
2025-04-18 15:10:22 +02:00
AdilZouitine
176557d770 Refactor learner_server.py for improved structure and clarity
- Removed unused imports and streamlined the code structure.
- Consolidated logging initialization and enhanced logging for training processes.
- Improved handling of training state loading and resume logic.
- Refactored transition and interaction message processing for better readability and maintainability.
- Added detailed comments and documentation for clarity.
2025-04-18 15:10:22 +02:00
AdilZouitine
3beab33fac Refactor imports in modeling_sac.py for improved organization
- Rearranged import statements for better readability.
- Removed unused imports and streamlined the code structure.
2025-04-18 15:10:22 +02:00
AdilZouitine
c0ba4b4954 Refactor SACConfig properties for improved readability
- Simplified the `image_features` property to directly iterate over `input_features`.
- Removed unused imports and unnecessary code related to main execution, enhancing clarity and maintainability.
2025-04-18 15:10:22 +02:00
AdilZouitine
8fb373aeb2 fix 2025-04-18 15:10:22 +02:00
AdilZouitine
5a0ee06651 Enhance logging for actor and learner servers
- Implemented process-specific logging for actor and learner servers to improve traceability.
- Created a dedicated logs directory and ensured it exists before logging.
- Initialized logging with explicit log files for each process, including actor transitions, interactions, and policy.
- Updated the actor CLI to validate configuration and set up logging accordingly.
2025-04-18 15:10:22 +02:00
Michel Aractingi
05a237ce10 Added gripper control mechanism to gym_manipulator
Moved HilSerl env config to configs/env/configs.py
fixes in actor_server and modeling_sac and configuration_sac
added the possibility of ignoring missing keys in env_cfg in get_features_from_env_config function
2025-04-18 15:10:22 +02:00
AdilZouitine
88cc2b8fc8 Add WrapperConfig for environment wrappers and update SACConfig properties
- Introduced `WrapperConfig` dataclass for environment wrapper configurations.
- Updated `ManiskillEnvConfig` to include a `wrapper` field for enhanced environment management.
- Modified `SACConfig` to return `None` for `observation_delta_indices` and `action_delta_indices` properties.
- Refactored `make_robot_env` function to improve readability and maintainability.
2025-04-18 15:10:22 +02:00
Michel Aractingi
b69132c79d Change HILSerlRobotEnvConfig to inherit from EnvConfig
Added support for hil_serl classifier to be trained with train.py
run classifier training by python lerobot/scripts/train.py --policy.type=hilserl_classifier
fixes in find_joint_limits, control_robot, end_effector_control_utils
2025-04-18 15:10:21 +02:00
AdilZouitine
db897a1619 [WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
2025-04-18 15:09:46 +02:00
AdilZouitine
0b5b62c8fb Add wandb run id in config 2025-04-18 15:09:46 +02:00
AdilZouitine
056f79d358 [WIP] Non functional yet
Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
2025-04-18 15:09:46 +02:00
Michel Aractingi
114ec644d0 Change config logic in:
- gym_manipulator
- find_joint_limits
- end_effector_utils
2025-04-18 15:09:45 +02:00
AdilZouitine
26ee8b6ae5 Add .devcontainer to .gitignore for improved development environment management 2025-04-18 15:09:27 +02:00
AdilZouitine
38e8864284 Add task field to frame_dict in ReplayBuffer and simplify save_episode calls
- Introduced a new "task" field in frame_dict to meet the requirements of LeRobotDataset.
- Removed task_name parameter from save_episode calls for consistency.
2025-04-18 15:09:27 +02:00
AdilZouitine
80d566eb56 Handle new config with sac 2025-04-18 15:09:27 +02:00
AdilZouitine
bb5a95889f Handle multi optimizers 2025-04-18 15:09:27 +02:00
pre-commit-ci[bot]
0ea27704f6 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:09:25 +02:00
Michel Aractingi
2abbd60a0d Removed depleted files and scripts 2025-04-18 15:07:48 +02:00
pre-commit-ci[bot]
1c8daf11fd [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:07:46 +02:00
AdilZouitine
cdcf346061 Update tensor device assignment in ReplayBuffer class
- Changed the device assignment for tensors in the ReplayBuffer class from `device` to `storage_device` for consistency and improved resource management.
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot]
42f95e827d [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine
618ed00d45 Initialize log_alpha with the logarithm of temperature_init in SACPolicy
- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot]
50d8db481e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine
e4a5971ffd Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary.
- Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity.
- Simplified the assertion in the `Policy` class for better readability.
2025-04-18 15:06:52 +02:00
AdilZouitine
36f9ccd851 Add intervention rate tracking in act_with_policy function
- Introduced counters for tracking intervention steps and total steps during training.
- Calculated and logged the intervention rate at the end of each episode.
- Reset intervention counters after each episode to ensure accurate tracking.
2025-04-18 15:06:52 +02:00
AdilZouitine
787aee0e60 - Updated the logging condition to use log_freq directly instead of accessing it through cfg.training.log_freq for improved readability and speed. 2025-04-18 15:06:52 +02:00
Eugene Mironov
0341a38fdd [PORT HIL-SERL] Optimize training loop, extract config usage (#855)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:06:52 +02:00
AdilZouitine
ffbed4a141 Enhance training information logging in learner server
- Added tracking for replay buffer size and offline replay buffer size during training steps.
2025-04-18 15:06:52 +02:00
AdilZouitine
03fe0f054b Update configuration files for improved performance and flexibility
- Increased frame rate in `maniskill_example.yaml` from 20 to 400 for enhanced simulation speed.
- Updated `sac_maniskill.yaml` to set `dataset_repo_id` to null and adjusted `grad_clip_norm` from 10.0 to 40.0.
- Changed `storage_device` from "cpu" to "cuda" for better resource utilization.
- Modified `save_freq` from 2000000 to 1000000 to optimize saving intervals.
- Enhanced input normalization parameters for `observation.state` and `observation.image` in SAC policy.
- Adjusted `num_critics` from 10 to 2 and `policy_parameters_push_frequency` from 1 to 4 for improved training dynamics.
- Updated `learner_server.py` to utilize `offline_buffer_capacity` for replay buffer initialization.
- Changed action multiplier in `maniskill_manipulator.py` from 1 to 0.03 for finer control over actions.
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot]
fd74c194b6 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine
0959694bab Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library.
- Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
2025-04-18 15:06:52 +02:00
Michel Aractingi
7b01e16439 Add end effector action space to hil-serl (#861)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 15:06:52 +02:00
AdilZouitine
66816fd871 Enhance SAC configuration and policy with gradient clipping and temperature management
- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot]
599326508f [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:52 +02:00
AdilZouitine
2f04d0d2b9 Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
2025-04-18 15:06:52 +02:00
AdilZouitine
e002c5ec56 Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method

- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
2025-04-18 15:06:52 +02:00
s1lent4gnt
3dfb37e976 [Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 15:06:52 +02:00
Eugene Mironov
b6a2200983 [HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 15:06:52 +02:00
pre-commit-ci[bot]
85fe8a3f4e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-04-18 15:06:51 +02:00
AdilZouitine
bb69cb3c8c Add storage device configuration for SAC policy and replay buffer
- Introduce `storage_device` parameter in SAC configuration and training settings
- Update learner server to use configurable storage device for replay buffer
- Reduce online buffer capacity in ManiSkill configuration
- Modify replay buffer initialization to support custom storage device
2025-04-18 15:04:58 +02:00
AdilZouitine
ae51c19b3c Add memory optimization option to ReplayBuffer
- Introduce `optimize_memory` parameter to reduce memory usage in replay buffer
- Implement simplified memory optimization by not storing duplicate next_states
- Update learner server and buffer initialization to use memory optimization by default
2025-04-18 15:04:58 +02:00
AdilZouitine
9ea79f8a76 Add storage device parameter to replay buffer initialization
- Specify storage device for replay buffer to optimize memory management
2025-04-18 15:04:58 +02:00
AdilZouitine
1d4ec50a58 Refactor ReplayBuffer with tensor-based storage and improved sampling efficiency
- Replaced list-based memory storage with pre-allocated tensor storage
- Optimized sampling process with direct tensor indexing
- Added support for DrQ image augmentation during sampling for offline dataset
- Improved dataset conversion with more robust episode handling
- Enhanced buffer initialization and state tracking
- Added comprehensive testing for buffer conversion and sampling
2025-04-18 15:04:58 +02:00
AdilZouitine
4c73891575 Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modify vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
2025-04-18 15:04:58 +02:00
Michel Aractingi
d3b84ecd6f Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.
Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:58 +02:00
Eugene Mironov
e1d55c7a44 [Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722) 2025-04-18 15:04:56 +02:00
AdilZouitine
85242cac67 Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi
0d88a5ee09 - Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the update_policy_parameters and remove strict=False
- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function
- Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:44 +02:00
AdilZouitine
62e237bdee Re-enable parameter push thread in learner server
- Uncomment and start the param_push_thread
- Restore thread joining for param_push_thread
2025-04-18 15:04:44 +02:00
AdilZouitine
c85f88fb62 Improve wandb logging and custom step tracking in logger
- Modify logger to support multiple custom step keys
- Update logging method to handle custom step keys more flexibly

- Enhance logging of optimization step and frequency
Co-authored-by: michel-aractingi  <michel.aractingi@gmail.com>
2025-04-18 15:04:44 +02:00
AdilZouitine
a90f4872f2 Add maniskill support.
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi
a16ea283f5 Fixed bug in the action scale of the intervention actions and offline dataset actions. (scale by inverse delta)
Co-authored-by: Adil Zouitine <adizouitinegm@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi
8209a6dfb7 Modified crop_dataset_roi interface to automatically write the cropped parameters to a json file in the meta of the dataset
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi
b5fbeb7401 Optimized the replay buffer from the memory side to store data on cpu instead of a gpu device and send the batches to the gpu.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:44 +02:00
Michel Aractingi
2ac25b02e2 nit
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
39fe4b1301 removed uncomment in actor server
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
140e30e386 Changed the init_final value to center the starting mean and std of the policy
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
ddcc0415e4 Changed bounds for a new so100 robot
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
5195f40fd3 Hardcoded some normalization parameters. TODO refactor
Added masking actions on the level of the intervention actions and offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
98c6557869 fix log_alpha in modeling_sac: change to nn.parameter
added pretrained vision model in policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
ee820859d3 Added logging for interventions to monitor the rate of interventions through time
Added an s keyboard command to force success in the case the reward classifier fails

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:43 +02:00
Michel Aractingi
5d6879d93a Added possiblity to record and replay delta actions during teleoperation rather than absolute actions
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:42 +02:00
Yoel
fae47d58d3 [PORT-Hilserl] classifier fixes (#695)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Eugene Mironov
3a07301365 [Port HIL-SERL] Add resnet-10 as default encoder for HIL-SERL (#696)
Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
f1af97dc9c - Added JointMaskingActionSpace wrapper in gym_manipulator in order to select which joints will be controlled. For example, we can disable the gripper actions for some tasks.
- Added Nan detection mechanisms in the actor, learner and gym_manipulator for the case where we encounter nans in the loop.
- changed the non-blocking in the `.to(device)` functions to only work for the case of cuda because they were causing nans when running the policy on mps
- Added some joint clipping and limits in the env, robot and policy configs. TODO clean this part and make the limits in one config file only.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
f2266101df Added sac_real config file in the policym configs dir.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
9784d8a47f Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Eugene Mironov
af769abd8d [HIL-SERL port] Add Reward classifier benchmark tracking to chose best visual encoder (#688) 2025-04-18 15:04:13 +02:00
Michel Aractingi
12c13e320e - Added lerobot/scripts/server/gym_manipulator.py that contains all the necessary wrappers to run a gym-style env around the real robot.
- Added `lerobot/scripts/server/find_joint_limits.py` to test the min and max angles of the motion you wish the robot to explore during RL training.
- Added logic in `manipulator.py` to limit the maximum possible joint angles to allow motion within a predefined joint position range. The limits are specified in the yaml config for each robot. Checkout the so100.yaml.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
273fa2e6e1 fixed bug in crop_dataset_roi.py
added missing buffer.pt in server dir

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
d143043037 Added additional wrappers for the environment: Action repeat, keyboard interface, reset wrapper
Tested the reset mechanism and keyboard interface and the convert wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
ca45c34ad5 Added crop_dataset_roi.py that allows you to load a lerobotdataset -> crop its images -> create a new lerobot dataset with the cropped and resized images.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
b1679050de - Added base gym env class for the real robot environment.
- Added several wrappers around the base gym env robot class.
- Including: time limit, reward classifier, crop images, preprocess observations.
- Added an interactive script crop_roi.py where the user can interactively select the roi in the observation images and return the correct crop values that will improve the policy and reward classifier performance.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
d2c41b35db - Refactor observation encoder in modeling_sac.py
- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameters push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Yoel
bc7b6d3daf [Port HIL-SERL] Add HF vision encoder option in SAC (#651)
Added support with custom pretrained vision encoder to the modeling sac implementation. Great job @ChorntonYoel !
2025-04-18 15:04:13 +02:00
Michel Aractingi
2516101cba Cleaned learner_server.py. Added several block function to improve readability.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
aebea08a99 Added support for checkpointing the policy. We can save and load the policy state dict, optimizers state, optimization step and interaction step
Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
03616db82c Removed unnecessary time.sleep in the streaming server on the learner side
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
93c4fc198f Added missing config files env/maniskill_example.yaml and policy/sac_maniskill.yaml that are necessary to run the lerobot implementation of sac with the maniskill baselines.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
8cd44ae163 - Added additional logging information in wandb around the timings of the policy loop and optimization loop.
- Optimized critic design that improves the performance of the learner loop by a factor of 2
- Cleaned the code and fixed style issues

- Completed the config with actor_learner_config field that contains host-ip and port elemnts that are necessary for the actor-learner servers.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
2ae657f568 FREEDOM, added back the optimization loop code in learner_server.py
Ran experiment with pushcube env from maniskill. The learning seem to work.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
Michel Aractingi
508f5d1407 Added server directory in lerobot/scripts that contains scripts and the protobuf message types to split training into two processes, acting and learning. The actor rollouts the policy and collects interaction data while the learner recieves the data, trains the policy and sends the updated parameters to the actor. The two scripts are ran simultaneously
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-04-18 15:04:13 +02:00
AdilZouitine
c8b1132846 Stable version of rlpd + drq 2025-04-18 15:04:10 +02:00
AdilZouitine
ef777993cd Add type annotations and restructure SACConfig class fields 2025-04-18 15:03:51 +02:00
Adil Zouitine
760d60ad4b Change SAC policy implementation with configuration and modeling classes 2025-04-18 15:03:51 +02:00
Adil Zouitine
875c0271b7 SAC works 2025-04-18 15:03:51 +02:00
Adil Zouitine
57344bfde5 [WIP] correct sac implementation 2025-04-18 15:03:51 +02:00
Adil Zouitine
46827fb002 Add rlpd tricks 2025-04-18 15:03:51 +02:00
Adil Zouitine
2fd78879f6 SAC works 2025-04-18 15:03:51 +02:00
Adil Zouitine
e8449e9630 remove breakpoint 2025-04-18 15:03:51 +02:00
Adil Zouitine
a0e2be8b92 [WIP] correct sac implementation 2025-04-18 15:03:51 +02:00
Michel Aractingi
181727c0fe Extend reward classifier for multiple camera views (#626) 2025-04-18 15:03:50 +02:00
Eugene Mironov
d1d6ffd23c [Port HIL_SERL] Final fixes for the Reward Classifier (#598) 2025-04-18 15:03:01 +02:00
Michel Aractingi
e5801f467f added temporary fix for missing task_index key in online environment 2025-04-18 15:03:01 +02:00
Michel Aractingi
c6ca9523de split encoder for critic and actor 2025-04-18 15:03:01 +02:00
Michel Aractingi
642e3a3274 style fixes 2025-04-18 15:03:01 +02:00
KeWang1017
146148c48c Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2025-04-18 15:03:01 +02:00
KeWang1017
8f15835daa Refine SAC configuration and policy for enhanced performance
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2025-04-18 15:03:01 +02:00
KeWang1017
022bd65125 Refactor SACPolicy for improved action sampling and standard deviation handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2025-04-18 15:03:01 +02:00
KeWang1017
63d8c96514 trying to get sac running 2025-04-18 15:03:01 +02:00
Michel Aractingi
4624a836e5 Added normalization schemes and style checks 2025-04-18 15:03:01 +02:00
Michel Aractingi
ad7eea132d added optimizer and sac to factory.py 2025-04-18 15:02:59 +02:00
Eugene Mironov
22a1899ff4 [HIL-SERL PORT] Fix linter issues (#588) 2025-04-18 15:02:44 +02:00
Eugene Mironov
17a3a31b5f [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) 2025-04-18 15:02:42 +02:00
Michel Aractingi
1a8b99e360 added comments from kewang 2025-04-18 15:02:13 +02:00
KeWang1017
6db2154f28 Enhance SAC configuration and policy with new parameters and subsampling logic
- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
2025-04-18 15:02:13 +02:00
KeWang
be3adda95f Port SAC WIP (#581)
Co-authored-by: KeWang1017 <ke.wang@helloleap.ai>
2025-04-18 15:02:13 +02:00
Michel Aractingi
9d48d236c1 completed losses 2025-04-18 15:02:13 +02:00
Michel Aractingi
b57d6a7776 nit in control_robot.py 2025-04-18 15:02:13 +02:00
Michel Aractingi
d1f76cba8e Update lerobot/scripts/train_hilserl_classifier.py
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2025-04-18 15:02:13 +02:00
Eugene Mironov
d78cef1fee Fixup 2025-04-18 15:02:13 +02:00
Michel Aractingi
30a808c0ae Add human intervention mechanism and eval_robot script to evaluate policy on the robot (#541)
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2025-04-18 15:02:13 +02:00
Yoel
4a7f85a6ec Reward classifier and training (#528)
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-04-18 15:02:13 +02:00
k1000dai
b43ece8934 Add pythno3-dev in Dockerfile to build and modify Readme.md , python-dev to python3-dev (#987)
Co-authored-by: makolon <smakolon385@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-04-17 16:17:07 +02:00
Alex Thiele
c10c5a0e64 Fix --width --height type parsing on opencv and intelrealsense scripts (#556)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-04-17 15:19:23 +02:00
Junshan Huang
a8db91c40e Fix Windows HTML visualization to make videos could be seen (#647)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-04-17 15:07:28 +02:00
HUANG TZU-CHUN
0f5f7ac780 Fix broken links in examples/4_train_policy_with_script.md (#697) 2025-04-17 14:59:43 +02:00
pre-commit-ci[bot]
768e36660d [pre-commit.ci] pre-commit autoupdate (#980)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-14 21:55:06 +02:00
Caroline Pascal
790d6740ba fix(installation): adding note on ffmpeg version during installation (#976)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-04-14 15:36:31 +02:00
191 changed files with 16857 additions and 9033 deletions

3
.gitignore vendored
View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
.dev
# Logging
logs
tmp
@@ -26,6 +26,7 @@ outputs
# VS Code
.vscode
.devcontainer
# HPC
nautilus/*.yaml

View File

@@ -46,9 +46,9 @@ repos:
rev: v3.19.1
hooks:
- id: pyupgrade
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.4
rev: v0.11.5
hooks:
- id: ruff
args: [--fix]
@@ -57,7 +57,7 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.24.2
rev: v8.24.3
hooks:
- id: gitleaks

View File

@@ -103,13 +103,20 @@ When using `miniconda`, install `ffmpeg` in your environment:
conda install ffmpeg -c conda-forge
```
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
> ```bash
> conda install ffmpeg=7.1.1 -c conda-forge
> ```
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
Install 🤗 LeRobot:
```bash
pip install -e .
```
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)

View File

@@ -32,7 +32,11 @@ import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -94,7 +98,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@@ -104,7 +112,10 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -120,7 +131,11 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -275,7 +290,9 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
benchmark_row = benchmark_decoding(

View File

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
speech-dispatcher portaudio19-dev libgeos-dev \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
# Install ffmpeg build dependencies. See:

View File

@@ -191,7 +191,7 @@ python lerobot/scripts/configure_motor.py \
--brand feetech \
--model sts3215 \
--baudrate 1000000 \
--id 1
--ID 1
```
> [!NOTE]
@@ -204,7 +204,7 @@ python lerobot/scripts/configure_motor.py \
--brand feetech \
--model sts3215 \
--baudrate 1000000 \
--id 2
--ID 2
```
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.

View File

@@ -138,7 +138,7 @@ python lerobot/scripts/configure_motor.py \
--brand feetech \
--model sts3215 \
--baudrate 1000000 \
--id 1
--ID 1
```
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
@@ -150,7 +150,7 @@ python lerobot/scripts/configure_motor.py \
--brand feetech \
--model sts3215 \
--baudrate 1000000 \
--id 2
--ID 2
```
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.

View File

@@ -0,0 +1,94 @@
# Training a HIL-SERL Reward Classifier with LeRobot
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
---
## Training Script Overview
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).
4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.
---
## Configuring with Hydra
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
### Config File Setup
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
## Typical logs and metrics
```
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
After that, you will see training log like this one:
```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```
or evaluation log like:
```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```
### Metrics Tracking with Weights & Biases (WandB)
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
- **Training Metrics**:
- `train/accuracy`
- `train/loss`
- `train/dataloading_s`
- **Evaluation Metrics**:
- `eval/accuracy`
- `eval/loss`
- `eval/eval_s`
#### Additional Features
You can also log sample predictions during evaluation. Each logged sample will include:
- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.
#### Generate protobuf files
```bash
python -m grpc_tools.protoc \
-I lerobot/scripts/server \
--python_out=lerobot/scripts/server \
--grpc_python_out=lerobot/scripts/server \
lerobot/scripts/server/hilserl.proto
```

View File

@@ -32,7 +32,10 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")

View File

@@ -22,7 +22,10 @@ from pathlib import Path
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -77,7 +80,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# We can then instantiate the dataset with these delta_timestamps configuration.

View File

@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly
## The training script
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:
- Initialize/load a configuration for the following steps using.
- Instantiates a dataset.
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig
def train(cfg: TrainPipelineConfig):
```
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a
## Specifying values from the CLI
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=lerobot/pusht \
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \
Let's break this down:
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
```bash
python lerobot/scripts/train.py \
--policy.type=act \
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
```bash
python lerobot/scripts/train.py \
--policy.type=act \

View File

@@ -83,7 +83,7 @@ python lerobot/scripts/configure_motor.py \
--brand dynamixel \
--model xl330-m288 \
--baudrate 1000000 \
--id 1
--ID 1
```
Then unplug your first motor and plug the second motor and set its ID to 2.
@@ -93,7 +93,7 @@ python lerobot/scripts/configure_motor.py \
--brand dynamixel \
--model xl330-m288 \
--baudrate 1000000 \
--id 2
--ID 2
```
Redo the process for all your motors until ID 6.
@@ -830,11 +830,6 @@ It contains:
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
Troubleshooting:
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:

View File

@@ -26,7 +26,10 @@ import math
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
@@ -51,7 +54,24 @@ def main():
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# Load the last 10% of episodes of the dataset as a validation set.

View File

@@ -1,4 +0,0 @@
from .camera import Camera
from .configs import CameraConfig
__all__ = ["Camera", "CameraConfig"]

View File

@@ -1,25 +0,0 @@
import abc
import numpy as np
class Camera(abc.ABC):
@abc.abstractmethod
def connect(self):
pass
@abc.abstractmethod
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
pass
@abc.abstractmethod
def async_read(self) -> np.ndarray:
pass
@abc.abstractmethod
def disconnect(self):
pass
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@@ -1,11 +0,0 @@
import abc
from dataclasses import dataclass
import draccus
@dataclass
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)

View File

@@ -1,4 +0,0 @@
from .camera_realsense import RealSenseCamera
from .configuration_realsense import RealSenseCameraConfig
__all__ = ["RealSenseCamera", "RealSenseCameraConfig"]

View File

@@ -1,4 +0,0 @@
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]

View File

@@ -1,37 +0,0 @@
from dataclasses import dataclass
from ..configs import CameraConfig
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
OpenCVCameraConfig(0, 30, 640, 480)
OpenCVCameraConfig(0, 60, 640, 480)
OpenCVCameraConfig(0, 90, 640, 480)
OpenCVCameraConfig(0, 30, 1280, 720)
```
"""
camera_index: int
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")

View File

@@ -1,21 +0,0 @@
from .camera import Camera
from .configs import CameraConfig
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
cameras = {}
for key, cfg in camera_configs.items():
if cfg.type == "opencv":
from .opencv import OpenCVCamera
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from .intel.camera_realsense import RealSenseCamera
cameras[key] = RealSenseCamera(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
return cameras

View File

@@ -17,15 +17,12 @@ from pathlib import Path
from huggingface_hub.constants import HF_HOME
OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
ROBOTS = "robots"
TELEOPERATORS = "teleoperators"
# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
@@ -37,16 +34,12 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
)
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / ".calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()

View File

@@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
dataset_len: int,
min_num_samples: int = 100,
max_num_samples: int = 10_000,
power: float = 0.75,
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
@@ -123,7 +126,9 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_feature_stats(
stats_ft_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
@@ -152,7 +157,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_stats(
stats_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.

View File

@@ -72,7 +72,7 @@ from lerobot.common.datasets.video_utils import (
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robots.utils import Robot
from lerobot.common.robot_devices.robots.utils import Robot
CODEBASE_VERSION = "v2.1"
@@ -318,7 +318,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
features = {**(features or {}), **get_features_from_robot(robot)}
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
@@ -821,7 +821,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
episode_index=self.episode_buffer["episode_index"],
image_key=key,
frame_index=frame_index,
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
@@ -867,7 +869,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"image",
"video",
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])

View File

@@ -154,14 +154,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
OnlineBuffer.OCCUPANCY_MASK_KEY: {
"dtype": np.dtype("?"),
"shape": (buffer_capacity,),
},
OnlineBuffer.INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
complete_data_spec[k] = {
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):

View File

@@ -77,7 +77,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

View File

@@ -43,7 +43,10 @@ class EpisodeAwareSampler:
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
range(
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
)
self.indices = indices

View File

@@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import (
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robots.utils import Robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -225,7 +225,10 @@ def load_episodes(local_dir: Path) -> dict:
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
episode_stats = {
"episode_index": episode_index,
"stats": serialize_dict(episode_stats),
}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
@@ -409,7 +412,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
if names is not None and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
@@ -540,7 +543,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
delta_timestamps: dict[str, list[float]],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be

View File

@@ -27,7 +27,7 @@ from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
from lerobot.common.robots.aloha.configuration_aloha import AlohaRobotConfig
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
LOCAL_DIR = Path("data/")
@@ -118,7 +118,10 @@ DATASETS = {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_candy": {
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
@@ -167,13 +170,22 @@ DATASETS = {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_static_ziploc_slide": {
"single_task": "Slide open the ziploc bag.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
@@ -194,10 +206,19 @@ DATASETS = {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
@@ -207,13 +228,31 @@ DATASETS = {
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",

View File

@@ -141,8 +141,8 @@ from lerobot.common.datasets.video_utils import (
get_image_pixel_channels,
get_video_info,
)
from lerobot.common.robots import RobotConfig
from lerobot.common.robots.utils import make_robot_config
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.robot_devices.robots.utils import make_robot_config
V16 = "v1.6"
V20 = "v2.0"
@@ -379,7 +379,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
@@ -402,7 +407,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
[
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True,
env=env,
)
@@ -410,7 +425,11 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
["git", "lfs", "ls-files", "-n"],
cwd=work_dir,
capture_output=True,
text=True,
check=True,
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
@@ -424,7 +443,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@@ -451,7 +474,11 @@ def convert_dataset(
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
repo_id=repo_id,
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
)
branch = "main"
if test_branch:
@@ -509,12 +536,21 @@ def convert_dataset(
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
repo_id=GITATTRIBUTES_REF,
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
repo_id,
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
@@ -543,7 +579,11 @@ def convert_dataset(
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
{
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@@ -572,7 +612,12 @@ def convert_dataset(
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)

View File

@@ -37,8 +37,16 @@ import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.utils import (
EPISODES_STATS_PATH,
STATS_PATH,
load_stats,
write_info,
)
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats,
)
V20 = "v2.0"
V21 = "v2.1"
@@ -79,10 +87,16 @@ def convert_dataset(
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
repo_id=dataset.repo_id,
filename=STATS_PATH,
revision=branch,
repo_type="dataset",
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
path_in_repo=STATS_PATH,
repo_id=dataset.repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

View File

@@ -17,7 +17,11 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
get_feature_stats,
sample_indices,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
@@ -95,5 +99,9 @@ def check_aggregate_stats(
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
val,
reference_stats[key][stat],
rtol=rtol,
atol=atol,
err_msg=err_msg,
)

View File

@@ -14,10 +14,12 @@
import abc
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import draccus
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -53,7 +55,7 @@ class AlohaEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"agent_pos": OBS_ROBOT,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
}
@@ -94,8 +96,8 @@ class PushtEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"environment_state": OBS_ENV_STATE,
"agent_pos": OBS_ROBOT,
"environment_state": OBS_ENV,
"pixels": OBS_IMAGE,
}
)
@@ -136,7 +138,7 @@ class XarmEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"agent_pos": OBS_STATE,
"agent_pos": OBS_ROBOT,
"pixels": OBS_IMAGE,
}
)
@@ -154,3 +156,135 @@ class XarmEnv(EnvConfig):
"visualization_height": self.visualization_height,
"max_episode_steps": self.episode_length,
}
@dataclass
class VideoRecordConfig:
"""Configuration for video recording in ManiSkill environments."""
enabled: bool = False
record_dir: str = "videos"
trajectory_name: str = "trajectory"
@dataclass
class WrapperConfig:
"""Configuration for environment wrappers."""
joint_masking_action_space: list[bool] | None = None
@dataclass
class EEActionSpaceConfig:
"""Configuration parameters for end-effector action space."""
x_step_size: float
y_step_size: float
z_step_size: float
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
use_gamepad: bool = False
@dataclass
class EnvWrapperConfig:
"""Configuration for environment wrappers."""
display_cameras: bool = False
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
control_time_s: float = 20.0
fixed_reset_joint_positions: Optional[Any] = None
reset_time_s: float = 5.0
joint_masking_action_space: Optional[Any] = None
ee_action_space_params: Optional[EEActionSpaceConfig] = None
use_gripper: bool = False
gripper_quantization_threshold: float | None = 0.8
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
open_gripper_on_reset: bool = False
@EnvConfig.register_subclass(name="gym_manipulator")
@dataclass
class HILSerlRobotEnvConfig(EnvConfig):
"""Configuration for the HILSerlRobotEnv environment."""
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvWrapperConfig] = None
fps: int = 10
name: str = "real_robot"
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
task: str = ""
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
reward_classifier: dict[str, str | None] = field(
default_factory=lambda: {
"pretrained_path": None,
"config_path": None,
}
)
def gym_kwargs(self) -> dict:
return {}
@EnvConfig.register_subclass("maniskill_push")
@dataclass
class ManiskillEnvConfig(EnvConfig):
"""Configuration for the ManiSkill environment."""
name: str = "maniskill/pushcube"
task: str = "PushCube-v1"
image_size: int = 64
control_mode: str = "pd_ee_delta_pose"
state_dim: int = 25
action_dim: int = 7
fps: int = 200
episode_length: int = 50
obs_type: str = "rgb"
render_mode: str = "rgb_array"
render_size: int = 64
device: str = "cuda"
robot: str = "so100" # This is a hack to make the robot config work
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
mock_gripper: bool = False
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
reward_classifier: dict[str, str | None] = field(
default_factory=lambda: {
"pretrained_path": None,
"config_path": None,
}
)
@property
def gym_kwargs(self) -> dict:
return {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"max_episode_steps": self.episode_length,
"control_mode": self.control_mode,
"sensor_configs": {"width": self.image_size, "height": self.image_size},
"num_envs": 1,
}

View File

@@ -37,29 +37,35 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
# TODO: You have to merge all tensors from agent key and extra key
# You don't keep sensor param key in the observation
# And you keep sensor data rgb
for key, img in observations.items():
if "images" not in key:
continue
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
# TODO(aliberts, rcadene): use transforms.ToTensor()?
if not torch.is_tensor(img):
img = torch.from_numpy(img)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
return_observations[imgkey] = img
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[key] = img
# obs state agent qpos and qvel
# image
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
@@ -68,7 +74,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = observations["observation.state"].float()
return return_observations
@@ -86,7 +93,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
else:
feature = ft
policy_key = env_cfg.features_map[key]
policy_key = env_cfg.features_map.get(key, key)
policy_features[policy_key] = feature
return policy_features
@@ -101,7 +108,9 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
if not (
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
):
warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning,
@@ -115,7 +124,9 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
)
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
def add_envs_task(
env: gym.vector.VectorEnv, observation: dict[str, Any]
) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")

View File

@@ -1,17 +0,0 @@
class DeviceNotConnectedError(ConnectionError):
"""Exception raised when the device is not connected."""
def __init__(self, message="This device is not connected. Try calling `connect()` first."):
self.message = message
super().__init__(self.message)
class DeviceAlreadyConnectedError(ConnectionError):
"""Exception raised when the device is already connected."""
def __init__(
self,
message="This device is already connected. Try not calling `connect()` twice.",
):
self.message = message
super().__init__(self.message)

View File

@@ -1 +0,0 @@
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus

View File

@@ -1,3 +0,0 @@
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
from .dynamixel_calibration import run_arm_calibration
from .tables import *

View File

@@ -1,206 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Should we implement FastSyncRead/Write?
# https://github.com/ROBOTIS-GIT/DynamixelSDK/pull/643
# https://github.com/ROBOTIS-GIT/DynamixelSDK/releases/tag/3.8.2
# https://emanual.robotis.com/docs/en/dxl/protocol2/#fast-sync-read-0x8a
# -> Need to check compatibility across models
import logging
from copy import deepcopy
from enum import Enum
from lerobot.common.utils.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
from .tables import (
AVAILABLE_BAUDRATES,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER_TABLE,
MODEL_RESOLUTION,
)
PROTOCOL_VERSION = 2.0
BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
logger = logging.getLogger(__name__)
class OperatingMode(Enum):
# DYNAMIXEL only controls current(torque) regardless of speed and position. This mode is ideal for a
# gripper or a system that only uses current(torque) control or a system that has additional
# velocity/position controllers.
CURRENT = 0
# This mode controls velocity. This mode is identical to the Wheel Mode(endless) from existing DYNAMIXEL.
# This mode is ideal for wheel-type robots.
VELOCITY = 1
# This mode controls position. This mode is identical to the Joint Mode from existing DYNAMIXEL. Operating
# position range is limited by the Max Position Limit(48) and the Min Position Limit(52). This mode is
# ideal for articulated robots that each joint rotates less than 360 degrees.
POSITION = 3
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4
# This mode controls both position and current(torque). Up to 512 turns are supported (-256[rev] ~
# 256[rev]). This mode is ideal for a system that requires both position and current control such as
# articulated robots or grippers.
CURRENT_POSITION = 5
# This mode directly controls PWM output. (Voltage Control Mode)
PWM = 16
class DriveMode(Enum):
NON_INVERTED = 0
INVERTED = 1
class TorqueMode(Enum):
ENABLED = 1
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import dynamixel_sdk as dxl
if length == 1:
data = [value]
elif length == 2:
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
elif length == 4:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
]
return data
class DynamixelMotorsBus(MotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
the motors. For more info, see the Dynamixel SDK Documentation:
https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20
"""
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
super().__init__(port, motors, calibration)
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
self._comm_success = dxl.COMM_SUCCESS
self._no_error = 0x00
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
pass
def _handshake(self) -> None:
self._assert_motors_exist()
def configure_motors(self) -> None:
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
for motor in self.motors:
self.write("Return_Delay_Time", motor, 0)
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
n_bytes = encoding_table[data_name]
ids_values[id_] = encode_twos_complement(ids_values[id_], n_bytes)
return ids_values
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
n_bytes = encoding_table[data_name]
ids_values[id_] = decode_twos_complement(ids_values[id_], n_bytes)
return ids_values
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
On Dynamixel Motors:
Present_Position = Actual_Position + Homing_Offset
"""
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
half_turn_homings[motor] = int(max_res / 2) - pos
return half_turn_homings
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
for n_try in range(1 + num_retry):
data_list, comm = self.packet_handler.broadcastPing(self.port_handler)
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
return {id_: data[0] for id_, data in data_list.items()}

View File

@@ -1,162 +0,0 @@
# {data_name: (address, size_byte)}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
X_SERIES_CONTROL_TABLE = {
"Model_Number": (0, 2),
"Model_Information": (2, 4),
"Firmware_Version": (6, 1),
"ID": (7, 1),
"Baud_Rate": (8, 1),
"Return_Delay_Time": (9, 1),
"Drive_Mode": (10, 1),
"Operating_Mode": (11, 1),
"Secondary_ID": (12, 1),
"Protocol_Type": (13, 1),
"Homing_Offset": (20, 4),
"Moving_Threshold": (24, 4),
"Temperature_Limit": (31, 1),
"Max_Voltage_Limit": (32, 2),
"Min_Voltage_Limit": (34, 2),
"PWM_Limit": (36, 2),
"Current_Limit": (38, 2),
"Acceleration_Limit": (40, 4),
"Velocity_Limit": (44, 4),
"Max_Position_Limit": (48, 4),
"Min_Position_Limit": (52, 4),
"Shutdown": (63, 1),
"Torque_Enable": (64, 1),
"LED": (65, 1),
"Status_Return_Level": (68, 1),
"Registered_Instruction": (69, 1),
"Hardware_Error_Status": (70, 1),
"Velocity_I_Gain": (76, 2),
"Velocity_P_Gain": (78, 2),
"Position_D_Gain": (80, 2),
"Position_I_Gain": (82, 2),
"Position_P_Gain": (84, 2),
"Feedforward_2nd_Gain": (88, 2),
"Feedforward_1st_Gain": (90, 2),
"Bus_Watchdog": (98, 1),
"Goal_PWM": (100, 2),
"Goal_Current": (102, 2),
"Goal_Velocity": (104, 4),
"Profile_Acceleration": (108, 4),
"Profile_Velocity": (112, 4),
"Goal_Position": (116, 4),
"Realtime_Tick": (120, 2),
"Moving": (122, 1),
"Moving_Status": (123, 1),
"Present_PWM": (124, 2),
"Present_Current": (126, 2),
"Present_Velocity": (128, 4),
"Present_Position": (132, 4),
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1),
}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
X_SERIES_BAUDRATE_TABLE = {
0: 9_600,
1: 57_600,
2: 115_200,
3: 1_000_000,
4: 2_000_000,
5: 3_000_000,
6: 4_000_000,
}
# {data_name: size_byte}
X_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": X_SERIES_CONTROL_TABLE["Homing_Offset"][1],
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
}
MODEL_ENCODING_TABLE = {
"x_series": X_SERIES_ENCODINGS_TABLE,
"xl330-m077": X_SERIES_ENCODINGS_TABLE,
"xl330-m288": X_SERIES_ENCODINGS_TABLE,
"xl430-w250": X_SERIES_ENCODINGS_TABLE,
"xm430-w350": X_SERIES_ENCODINGS_TABLE,
"xm540-w270": X_SERIES_ENCODINGS_TABLE,
"xc430-w150": X_SERIES_ENCODINGS_TABLE,
}
# {model: model_resolution}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#specifications
MODEL_RESOLUTION = {
"x_series": 4096,
"xl330-m077": 4096,
"xl330-m288": 4096,
"xl430-w250": 4096,
"xm430-w350": 4096,
"xm540-w270": 4096,
"xc430-w150": 4096,
}
# {model: model_number}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table-of-eeprom-area
MODEL_NUMBER_TABLE = {
"xl330-m077": 1190,
"xl330-m288": 1200,
"xl430-w250": 1060,
"xm430-w350": 1020,
"xm540-w270": 1120,
"xc430-w150": 1070,
}
# {model: available_operating_modes}
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#operating-mode11
MODEL_OPERATING_MODES = {
"xl330-m077": [0, 1, 3, 4, 5, 16],
"xl330-m288": [0, 1, 3, 4, 5, 16],
"xl430-w250": [1, 3, 4, 16],
"xm430-w350": [0, 1, 3, 4, 5, 16],
"xm540-w270": [0, 1, 3, 4, 5, 16],
"xc430-w150": [1, 3, 4, 16],
}
MODEL_CONTROL_TABLE = {
"x_series": X_SERIES_CONTROL_TABLE,
"xl330-m077": X_SERIES_CONTROL_TABLE,
"xl330-m288": X_SERIES_CONTROL_TABLE,
"xl430-w250": X_SERIES_CONTROL_TABLE,
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
"xc430-w150": X_SERIES_CONTROL_TABLE,
}
MODEL_BAUDRATE_TABLE = {
"x_series": X_SERIES_BAUDRATE_TABLE,
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
}
AVAILABLE_BAUDRATES = [
9_600,
19_200,
38_400,
57_600,
115_200,
230_400,
460_800,
500_000,
576_000,
921_600,
1_000_000,
1_152_000,
2_000_000,
2_500_000,
3_000_000,
3_500_000,
4_000_000,
]

View File

@@ -1,2 +0,0 @@
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
from .tables import *

View File

@@ -1,367 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from copy import deepcopy
from enum import Enum
from pprint import pformat
from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
from .tables import (
FIRMWARE_MAJOR_VERSION,
FIRMWARE_MINOR_VERSION,
MODEL_BAUDRATE_TABLE,
MODEL_CONTROL_TABLE,
MODEL_ENCODING_TABLE,
MODEL_NUMBER,
MODEL_NUMBER_TABLE,
MODEL_PROTOCOL,
MODEL_RESOLUTION,
SCAN_BAUDRATES,
)
DEFAULT_PROTOCOL_VERSION = 0
BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
logger = logging.getLogger(__name__)
class OperatingMode(Enum):
# position servo mode
POSITION = 0
# The motor is in constant speed mode, which is controlled by parameter 0x2e, and the highest bit 15 is
# the direction bit
VELOCITY = 1
# PWM open-loop speed regulation mode, with parameter 0x2c running time parameter control, bit11 as
# direction bit
PWM = 2
# In step servo mode, the number of step progress is represented by parameter 0x2a, and the highest bit 15
# is the direction bit
STEP = 3
class DriveMode(Enum):
NON_INVERTED = 0
INVERTED = 1
class TorqueMode(Enum):
ENABLED = 1
DISABLED = 0
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
import scservo_sdk as scs
if length == 1:
data = [value]
elif length == 2:
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
elif length == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
]
return data
def patch_setPacketTimeout(self, packet_length): # noqa: N802
"""
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
It fixes https://gitee.com/ftservo/SCServoSDK/issues/IBY2S6
The bug is fixed on the official Feetech SDK repo (https://gitee.com/ftservo/FTServo_Python)
but because that version is not published on PyPI, we rely on the (unofficial) on that is, which needs
patching.
"""
self.packet_start_time = self.getCurrentTime()
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
class FeetechMotorsBus(MotorsBus):
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
"""
available_baudrates = deepcopy(SCAN_BAUDRATES)
default_timeout = DEFAULT_TIMEOUT_MS
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
model_resolution_table = deepcopy(MODEL_RESOLUTION)
normalized_data = deepcopy(NORMALIZED_DATA)
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
):
super().__init__(port, motors, calibration)
self.protocol_version = protocol_version
self._assert_same_protocol()
import scservo_sdk as scs
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
self.port_handler, scs.PortHandler
)
self.packet_handler = scs.PacketHandler(protocol_version)
self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
self._comm_success = scs.COMM_SUCCESS
self._no_error = 0x00
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
def _assert_same_protocol(self) -> None:
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
raise RuntimeError("Some motors use an incompatible protocol.")
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
if instruction_name == "sync_read" and self.protocol_version == 1:
raise NotImplementedError(
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
)
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
raise NotImplementedError(
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
)
def _assert_same_firmware(self) -> None:
firmware_versions = self._read_firmware_version(self.ids)
if len(set(firmware_versions.values())) != 1:
raise RuntimeError(
"Some Motors use different firmware versions. Update their firmware first using Feetech's software. "
"Visit https://www.feetechrc.com/software."
)
def _handshake(self) -> None:
self._assert_motors_exist()
self._assert_same_firmware()
def configure_motors(self) -> None:
for motor in self.motors:
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
self.write("Return_Delay_Time", motor, 0)
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
# Note: this address is not in the official STS3215 Memory Table
self.write("Maximum_Acceleration", motor, 254)
self.write("Acceleration", motor, 254)
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
"""
On Feetech Motors:
Present_Position = Actual_Position - Homing_Offset
"""
half_turn_homings = {}
for motor, pos in positions.items():
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
half_turn_homings[motor] = pos - int(max_res / 2)
return half_turn_homings
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
self.write("Lock", name, 0, num_retry=num_retry)
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
for name in self._get_motors_list(motors):
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
self.write("Lock", name, 1, num_retry=num_retry)
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit)
return ids_values
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
for id_ in ids_values:
model = self._id_to_model(id_)
encoding_table = self.model_encoding_table.get(model)
if encoding_table and data_name in encoding_table:
sign_bit = encoding_table[data_name]
ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit)
return ids_values
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
return _split_into_byte_chunks(value, length)
def _broadcast_ping_p1(
self, known_motors_only: bool = True, n_motors: int | None = None, num_retry: int = 0
) -> dict[int, int]:
if known_motors_only:
ids = self.ids
else:
import scservo_sdk as scs
ids = range(scs.MAX_ID + 1)
ids_models = {}
motors_found = 0
for id_ in ids:
model_number = self.ping(id_, num_retry)
if model_number is not None:
ids_models[id_] = model_number
motors_found += 1
if motors_found >= n_motors:
break
return ids_models
def _broadcast_ping_p0(self) -> tuple[dict[int, int], int]:
import scservo_sdk as scs
data_list = {}
status_length = 6
rx_length = 0
wait_length = status_length * scs.MAX_ID
txpacket = [0] * 6
tx_time_per_byte = (1000.0 / self.port_handler.getBaudRate()) * 10.0
txpacket[scs.PKT_ID] = scs.BROADCAST_ID
txpacket[scs.PKT_LENGTH] = 2
txpacket[scs.PKT_INSTRUCTION] = scs.INST_PING
result = self.packet_handler.txPacket(self.port_handler, txpacket)
if result != scs.COMM_SUCCESS:
self.port_handler.is_using = False
return data_list, result
# set rx timeout
self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0)
rxpacket = []
while True:
rxpacket += self.port_handler.readPort(wait_length - rx_length)
rx_length = len(rxpacket)
if self.port_handler.isPacketTimeout(): # or rx_length >= wait_length
break
self.port_handler.is_using = False
if rx_length == 0:
return data_list, scs.COMM_RX_TIMEOUT
while True:
if rx_length < status_length:
return data_list, scs.COMM_RX_CORRUPT
# find packet header
for idx in range(0, (rx_length - 1)):
if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF):
break
if idx == 0: # found at the beginning of the packet
# calculate checksum
checksum = 0
for idx in range(2, status_length - 1): # except header & checksum
checksum += rxpacket[idx]
checksum = ~checksum & 0xFF
if rxpacket[status_length - 1] == checksum:
result = scs.COMM_SUCCESS
data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR]
del rxpacket[0:status_length]
rx_length = rx_length - status_length
if rx_length == 0:
return data_list, result
else:
result = scs.COMM_RX_CORRUPT
# remove header (0xFF 0xFF)
del rxpacket[0:2]
rx_length = rx_length - 2
else:
# remove unnecessary packets
del rxpacket[0:idx]
rx_length = rx_length - idx
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
self._assert_protocol_is_compatible("broadcast_ping")
for n_try in range(1 + num_retry):
ids_status, comm = self._broadcast_ping_p0()
if self._is_comm_success(comm):
break
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
logger.debug(self.packet_handler.getTxRxResult(comm))
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
return
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
if ids_errors:
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
return self._read_model_number(list(ids_status), raise_on_error)
def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
firmware_versions = {}
for id_ in motor_ids:
firm_ver_major, comm, error = self._read(
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firm_ver_minor, comm, error = self._read(
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
)
if not self._is_comm_success(comm) or self._is_error(error):
return
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
return firmware_versions
def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
model_numbers = {}
for id_ in motor_ids:
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
if not self._is_comm_success(comm) or self._is_error(error):
return
model_numbers[id_] = model_nb
return model_numbers

View File

@@ -1,202 +0,0 @@
FIRMWARE_MAJOR_VERSION = (0, 1)
FIRMWARE_MINOR_VERSION = (1, 1)
MODEL_NUMBER = (3, 2)
# See this link for STS3215 Memory Table:
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
# data_name: (address, size_byte)
STS_SMS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay_Time": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protection_Current": (28, 2),
"Angular_Resolution": (30, 1),
"Homing_Offset": (31, 2),
"Operating_Mode": (33, 1),
"Protective_Torque": (34, 1),
"Protection_Time": (35, 1),
"Overload_Torque": (36, 1),
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
"Over_Current_Protection_Time": (38, 1),
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
# SRAM
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Goal_Time": (44, 2),
"Goal_Speed": (46, 2),
"Torque_Limit": (48, 2),
"Lock": (55, 1),
"Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
"Present_Current": (69, 2), # read-only
# Not in the Memory Table
"Maximum_Acceleration": (85, 2),
}
SCS_SERIES_CONTROL_TABLE = {
# EPROM
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
"Model_Number": MODEL_NUMBER, # read-only
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Position_Limit": (9, 2),
"Max_Position_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protective_Torque": (37, 1),
"Protection_Time": (38, 1),
# SRAM
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Running_Time": (44, 2),
"Goal_Speed": (46, 2),
"Lock": (48, 1),
"Present_Position": (56, 2), # read-only
"Present_Speed": (58, 2), # read-only
"Present_Load": (60, 2), # read-only
"Present_Voltage": (62, 1), # read-only
"Present_Temperature": (63, 1), # read-only
"Sync_Write_Flag": (64, 1), # read-only
"Status": (65, 1), # read-only
"Moving": (66, 1), # read-only
}
STS_SMS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
}
SCS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
}
MODEL_CONTROL_TABLE = {
"sts_series": STS_SMS_SERIES_CONTROL_TABLE,
"scs_series": SCS_SERIES_CONTROL_TABLE,
"sms_series": STS_SMS_SERIES_CONTROL_TABLE,
"sts3215": STS_SMS_SERIES_CONTROL_TABLE,
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
"scs0009": SCS_SERIES_CONTROL_TABLE,
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
"sts_series": 4096,
"sms_series": 4096,
"scs_series": 1024,
"sts3215": 4096,
"sts3250": 4096,
"sm8512bl": 65536,
"scs0009": 1024,
}
MODEL_BAUDRATE_TABLE = {
"sts_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"sms_series": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
"sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
}
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Speed": 15,
}
MODEL_ENCODING_TABLE = {
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs_series": {},
"sts3215": STS_SMS_SERIES_ENCODINGS_TABLE,
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
"scs0009": {},
}
SCAN_BAUDRATES = [
4_800,
9_600,
14_400,
19_200,
38_400,
57_600,
115_200,
128_000,
250_000,
500_000,
1_000_000,
]
MODEL_NUMBER_TABLE = {
"sts3215": 777,
"sts3250": 2825,
"sm8512bl": 11272,
"scs0009": 1284,
}
MODEL_PROTOCOL = {
"sts_series": 0,
"sms_series": 0,
"scs_series": 1,
"sts3215": 0,
"sts3250": 0,
"sm8512bl": 0,
"scs0009": 1,
}

View File

@@ -1,987 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: N802
# This noqa is for the Protocols classes: PortHandler, PacketHandler GroupSyncRead/Write
# TODO(aliberts): Add block noqa when feature below is available
# https://github.com/astral-sh/ruff/issues/3711
import abc
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
from typing import Protocol, TypeAlias
import serial
from deepdiff import DeepDiff
from tqdm import tqdm
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
Value: TypeAlias = int | float
MAX_ID_RANGE = 252
logger = logging.getLogger(__name__)
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
ctrl_table = model_ctrl_table.get(model)
if ctrl_table is None:
raise KeyError(f"Control table for {model=} not found.")
return ctrl_table
def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]:
ctrl_table = get_ctrl_table(model_ctrl_table, model)
addr_bytes = ctrl_table.get(data_name)
if addr_bytes is None:
raise KeyError(f"Address for '{data_name}' not found in {model} control table.")
return addr_bytes
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
all_addr = []
all_bytes = []
for model in motor_models:
addr, bytes = get_address(model_ctrl_table, model, data_name)
all_addr.append(addr)
all_bytes.append(bytes)
if len(set(all_addr)) != 1:
raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_addr, strict=False))})."
)
if len(set(all_bytes)) != 1:
raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
f"({list(zip(motor_models, all_bytes, strict=False))})."
)
class MotorNormMode(Enum):
DEGREE = 0
RANGE_0_100 = 1
RANGE_M100_100 = 2
VELOCITY = 3
@dataclass
class MotorCalibration:
id: int
drive_mode: int
homing_offset: int
range_min: int
range_max: int
@dataclass
class Motor:
id: int
model: str
norm_mode: MotorNormMode
class JointOutOfRangeError(Exception):
def __init__(self, message="Joint is out of range"):
self.message = message
super().__init__(self.message)
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
self.baudrate: int
self.packet_start_time: float
self.packet_timeout: float
self.tx_time_per_byte: float
self.is_using: bool
self.port_name: str
self.ser: serial.Serial
def openPort(self): ...
def closePort(self): ...
def clearPort(self): ...
def setPortName(self, port_name): ...
def getPortName(self): ...
def setBaudRate(self, baudrate): ...
def getBaudRate(self): ...
def getBytesAvailable(self): ...
def readPort(self, length): ...
def writePort(self, packet): ...
def setPacketTimeout(self, packet_length): ...
def setPacketTimeoutMillis(self, msec): ...
def isPacketTimeout(self): ...
def getCurrentTime(self): ...
def getTimeSinceStart(self): ...
def setupPort(self, cflag_baud): ...
def getCFlagBaud(self, baudrate): ...
class PacketHandler(Protocol):
def getTxRxResult(self, result): ...
def getRxPacketError(self, error): ...
def txPacket(self, port, txpacket): ...
def rxPacket(self, port): ...
def txRxPacket(self, port, txpacket): ...
def ping(self, port, id): ...
def action(self, port, id): ...
def readTx(self, port, id, address, length): ...
def readRx(self, port, id, length): ...
def readTxRx(self, port, id, address, length): ...
def read1ByteTx(self, port, id, address): ...
def read1ByteRx(self, port, id): ...
def read1ByteTxRx(self, port, id, address): ...
def read2ByteTx(self, port, id, address): ...
def read2ByteRx(self, port, id): ...
def read2ByteTxRx(self, port, id, address): ...
def read4ByteTx(self, port, id, address): ...
def read4ByteRx(self, port, id): ...
def read4ByteTxRx(self, port, id, address): ...
def writeTxOnly(self, port, id, address, length, data): ...
def writeTxRx(self, port, id, address, length, data): ...
def write1ByteTxOnly(self, port, id, address, data): ...
def write1ByteTxRx(self, port, id, address, data): ...
def write2ByteTxOnly(self, port, id, address, data): ...
def write2ByteTxRx(self, port, id, address, data): ...
def write4ByteTxOnly(self, port, id, address, data): ...
def write4ByteTxRx(self, port, id, address, data): ...
def regWriteTxOnly(self, port, id, address, length, data): ...
def regWriteTxRx(self, port, id, address, length, data): ...
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
class GroupSyncRead(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.last_result: bool
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def makeParam(self): ...
def addParam(self, id): ...
def removeParam(self, id): ...
def clearParam(self): ...
def txPacket(self): ...
def rxPacket(self): ...
def txRxPacket(self): ...
def isAvailable(self, id, address, data_length): ...
def getData(self, id, address, data_length): ...
class GroupSyncWrite(Protocol):
def __init__(self, port, ph, start_address, data_length):
self.port: str
self.ph: PortHandler
self.start_address: int
self.data_length: int
self.is_param_changed: bool
self.param: list
self.data_dict: dict
def makeParam(self): ...
def addParam(self, id, data): ...
def removeParam(self, id): ...
def changeParam(self, id, data): ...
def clearParam(self): ...
def txPacket(self): ...
class MotorsBus(abc.ABC):
"""The main LeRobot class for implementing motors buses.
There are currently two implementations of this abstract class:
- DynamixelMotorsBus
- FeetechMotorsBus
Note: This class may evolve in the future should we add support for other manufacturers SDKs.
A MotorsBus allows to efficiently read and write to the attached motors.
It represents several motors daisy-chained together and connected through a serial port.
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
>>> Finding all available ports for the MotorsBus.
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
>>> Remove the usb cable from your MotorsBus and press Enter when done.
>>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
```
Example of usage for 1 Feetech sts3215 motor connected to the bus:
```python
motors_bus = FeetechMotorsBus(
port="/dev/tty.usbmodem575E0031751",
motors={"gripper": (6, "sts3215")},
)
motors_bus.connect()
position = motors_bus.read("Present_Position")
# Move from a few motor steps as an example
few_steps = 30
motors_bus.write("Goal_Position", position + few_steps)
# When done, properly disconnect the port using
motors_bus.disconnect()
```
"""
available_baudrates: list[int]
default_timeout: int
model_baudrate_table: dict[str, dict]
model_ctrl_table: dict[str, dict]
model_encoding_table: dict[str, dict]
model_number_table: dict[str, int]
model_resolution_table: dict[str, int]
normalized_data: list[str]
def __init__(
self,
port: str,
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
self.port = port
self.motors = motors
self.calibration = calibration if calibration else {}
self.port_handler: PortHandler
self.packet_handler: PacketHandler
self.sync_reader: GroupSyncRead
self.sync_writer: GroupSyncWrite
self._comm_success: int
self._no_error: int
self._id_to_model_dict = {m.id: m.model for m in self.motors.values()}
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
self._validate_motors()
def __len__(self):
return len(self.motors)
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Port: '{self.port}',\n"
f" Motors: \n{pformat(self.motors, indent=8, sort_dicts=False)},\n"
")',\n"
)
@cached_property
def _has_different_ctrl_tables(self) -> bool:
if len(self.models) < 2:
return False
first_table = self.model_ctrl_table[self.models[0]]
return any(
DeepDiff(first_table, get_ctrl_table(self.model_ctrl_table, model)) for model in self.models[1:]
)
@cached_property
def names(self) -> list[str]:
return list(self.motors)
@cached_property
def models(self) -> list[str]:
return [m.model for m in self.motors.values()]
@cached_property
def ids(self) -> list[int]:
return [m.id for m in self.motors.values()]
def _model_nb_to_model(self, motor_nb: int) -> str:
return self._model_nb_to_model_dict[motor_nb]
def _id_to_model(self, motor_id: int) -> str:
return self._id_to_model_dict[motor_id]
def _id_to_name(self, motor_id: int) -> str:
return self._id_to_name_dict[motor_id]
def _get_motor_id(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].id
elif isinstance(motor, int):
return motor
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motor_model(self, motor: NameOrID) -> int:
if isinstance(motor, str):
return self.motors[motor].model
elif isinstance(motor, int):
return self._id_to_model_dict[motor]
else:
raise TypeError(f"'{motor}' should be int, str.")
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
if motors is None:
return self.names
elif isinstance(motors, str):
return [motors]
elif isinstance(motors, list):
return motors.copy()
else:
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
else:
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
def _validate_motors(self) -> None:
if len(self.ids) != len(set(self.ids)):
raise ValueError(f"Some motors have the same id!\n{self}")
# Ensure ctrl table available for all models
for model in self.models:
get_ctrl_table(self.model_ctrl_table, model)
def _is_comm_success(self, comm: int) -> bool:
return comm == self._comm_success
def _is_error(self, error: int) -> bool:
return error != self._no_error
def _assert_motors_exist(self) -> None:
# TODO(aliberts): collect all wrong ids/models and display them at once
found_models = {}
for id_ in self.ids:
model_nb = self.ping(id_)
if model_nb is not None:
found_models[id_] = model_nb
expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
if set(found_models) != set(self.ids):
raise RuntimeError(
f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
f"But it found these motors on port '{self.port}':"
f"\n{pformat(found_models, indent=4, sort_dicts=False)}\n"
)
for id_, model in expected_models.items():
if found_models[id_] != model:
raise RuntimeError(
f"Motor '{self._id_to_name(id_)}' (id={id_}) is supposed to be of model_number={model} "
f"('{self._id_to_model(id_)}') but a model_number={found_models[id_]} "
"was found instead for that id."
)
@abc.abstractmethod
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
pass
@property
def is_connected(self) -> bool:
return self.port_handler.is_open
def connect(self, handshake: bool = True) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
)
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
elif handshake:
self._handshake()
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
) from e
self.set_timeout()
logger.debug(f"{self.__class__.__name__} connected.")
@abc.abstractmethod
def _handshake(self) -> None:
pass
@classmethod
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
bus = cls(port, {}, *args, **kwargs)
try:
bus.port_handler.openPort()
except (FileNotFoundError, OSError, serial.SerialException) as e:
raise ConnectionError(
f"Could not connect to port '{port}'. Make sure you are using the correct port."
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
) from e
baudrate_ids = {}
for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"):
bus.set_baudrate(baudrate)
ids_models = bus.broadcast_ping()
if ids_models:
tqdm.write(f"Motors found for {baudrate=}: {pformat(ids_models, indent=4)}")
baudrate_ids[baudrate] = list(ids_models)
return baudrate_ids
@abc.abstractmethod
def configure_motors(self) -> None:
pass
@abc.abstractmethod
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
pass
@abc.abstractmethod
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
pass
@contextmanager
def torque_disabled(self):
self.disable_torque()
try:
yield
finally:
self.enable_torque()
def set_timeout(self, timeout_ms: int | None = None):
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
self.port_handler.setPacketTimeoutMillis(timeout_ms)
def get_baudrate(self) -> int:
return self.port_handler.getBaudRate()
def set_baudrate(self, baudrate: int) -> None:
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
raise OSError("Failed to write bus baud rate.")
@property
def is_calibrated(self) -> bool:
return self.calibration == self.read_calibration()
def read_calibration(self) -> dict[str, MotorCalibration]:
offsets = self.sync_read("Homing_Offset", normalize=False)
mins = self.sync_read("Min_Position_Limit", normalize=False)
maxes = self.sync_read("Max_Position_Limit", normalize=False)
try:
drive_modes = self.sync_read("Drive_Mode", normalize=False)
except KeyError:
drive_modes = dict.fromkeys(self.names, 0)
calibration = {}
for name, motor in self.motors.items():
calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=drive_modes[name],
homing_offset=offsets[name],
range_min=mins[name],
range_max=maxes[name],
)
return calibration
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
for motor, calibration in calibration_dict.items():
self.write("Homing_Offset", motor, calibration.homing_offset)
self.write("Min_Position_Limit", motor, calibration.range_min)
self.write("Max_Position_Limit", motor, calibration.range_max)
self.calibration = calibration_dict
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
if motors is None:
motors = self.names
elif isinstance(motors, (str, int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
for motor in motors:
model = self._get_motor_model(motor)
max_res = self.model_resolution_table[model] - 1
self.write("Homing_Offset", motor, 0, normalize=False)
self.write("Min_Position_Limit", motor, 0, normalize=False)
self.write("Max_Position_Limit", motor, max_res, normalize=False)
self.calibration = {}
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
"""
This assumes motors present positions are roughly in the middle of their desired range
Step 1: Set homing and min max to 0
Step 2: Read Present_Position which will be Actual_Position since
Present_Position = Actual_Position ± Homing_Offset (1)
and Homing_Offset = 0 from step 1
Step 3: We want to set the Homing_Offset such that the current Present_Position to be half range of 1
revolution. For instance, if 1 revolution corresponds to 4095 (4096 steps), this means we want the
current Present_Position to be 2047.
In that example:
Present_Position = 2047 (2)
Actual_Position = X (read in step 2)
from (1) and (2):
=> Homing_Offset = ±(X - 2048)
"""
if motors is None:
motors = self.names
elif isinstance(motors, (str, int)):
motors = [motors]
else:
raise TypeError(motors)
self.reset_calibration(motors)
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
homing_offsets = self._get_half_turn_homings(actual_positions)
for motor, offset in homing_offsets.items():
self.write("Homing_Offset", motor, offset)
return homing_offsets
@abc.abstractmethod
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
pass
def record_ranges_of_motion(
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
"""
This assumes that the homing offsets have been set such that all possible values in the range of
motion are positive and that the zero is not crossed. To that end, `set_half_turn_homings` should
typically be called prior to this.
"""
if motors is None:
motors = self.names
elif isinstance(motors, (str, int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
start_positions = self.sync_read("Present_Position", motors, normalize=False)
mins = start_positions.copy()
maxes = start_positions.copy()
while True:
positions = self.sync_read("Present_Position", motors, normalize=False)
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
if display_values:
print("\n-------------------------------------------")
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
for name in motors:
print(f"{name:<15} | {mins[name]:>6} | {positions[name]:>6} | {maxes[name]:>6}")
if enter_pressed():
break
if display_values:
# Move cursor up to overwrite the previous output
move_cursor_up(len(motors) + 3)
return mins, maxes
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
normalized_values = {}
for id_, val in ids_values.items():
name = self._id_to_name(id_)
min_ = self.calibration[name].range_min
max_ = self.calibration[name].range_max
bounded_val = min(max_, max(min_, val))
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
normalized_values[id_] = ((bounded_val - min_) / (max_ - min_)) * 100
else:
# TODO(alibers): velocity and degree modes
raise NotImplementedError
return normalized_values
def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]:
if not self.calibration:
raise RuntimeError(f"{self} has no calibration registered.")
unnormalized_values = {}
for id_, val in ids_values.items():
name = self._id_to_name(id_)
min_ = self.calibration[name].range_min
max_ = self.calibration[name].range_max
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
bounded_val = min(100.0, max(-100.0, val))
unnormalized_values[id_] = int(((bounded_val + 100) / 200) * (max_ - min_) + min_)
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
bounded_val = min(100.0, max(0.0, val))
unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_)
else:
# TODO(alibers): velocity and degree modes
raise NotImplementedError
return unnormalized_values
@abc.abstractmethod
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
pass
@abc.abstractmethod
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
pass
def _serialize_data(self, value: int, length: int) -> list[int]:
"""
Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication
protocol. Depending on the protocol, split values can be in big-endian or little-endian order.
Supported data length for both Feetech and Dynamixel:
- 1 (for values 0 to 255)
- 2 (for values 0 to 65,535)
- 4 (for values 0 to 4,294,967,295)
"""
if value < 0:
raise ValueError(f"Negative values are not allowed: {value}")
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length)
if max_value is None:
raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].")
if value > max_value:
raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).")
return self._split_into_byte_chunks(value, length)
@abc.abstractmethod
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
"""Convert an integer into a list of byte-sized integers."""
pass
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
id_ = self._get_motor_id(motor)
for n_try in range(1 + num_retry):
model_number, comm, error = self.packet_handler.ping(self.port_handler, id_)
if self._is_comm_success(comm):
break
logger.debug(f"ping failed for {id_=}: {n_try=} got {comm=} {error=}")
if not self._is_comm_success(comm):
if raise_on_error:
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
else:
return
if self._is_error(error):
if raise_on_error:
raise RuntimeError(self.packet_handler.getRxPacketError(error))
else:
return
return model_number
@abc.abstractmethod
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
pass
def read(
self,
data_name: str,
motor: str,
*,
normalize: bool = True,
num_retry: int = 0,
) -> Value:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
id_value = self._decode_sign(data_name, {id_: value})
if normalize and data_name in self.normalized_data:
id_value = self._normalize(data_name, id_value)
return id_value[id_]
def _read(
self,
address: int,
length: int,
motor_id: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
if length == 1:
read_fn = self.packet_handler.read1ByteTxRx
elif length == 2:
read_fn = self.packet_handler.read2ByteTxRx
elif length == 4:
read_fn = self.packet_handler.read4ByteTxRx
else:
raise ValueError(length)
for n_try in range(1 + num_retry):
value, comm, error = read_fn(self.port_handler, motor_id, address)
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return value, comm, error
def write(
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
id_ = self.motors[motor].id
model = self.motors[motor].model
addr, length = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalized_data:
value = self._unnormalize(data_name, {id_: value})[id_]
value = self._encode_sign(data_name, {id_: value})[id_]
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _write(
self,
addr: int,
length: int,
motor_id: int,
value: int,
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[int, int]:
data = self._serialize_data(value, length)
for n_try in range(1 + num_retry):
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data)
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
elif self._is_error(error) and raise_on_error:
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
return comm, error
def sync_read(
self,
data_name: str,
motors: str | list[str] | None = None,
*,
normalize: bool = True,
num_retry: int = 0,
) -> dict[str, Value]:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
self._assert_protocol_is_compatible("sync_read")
names = self._get_motors_list(motors)
ids = [self.motors[name].id for name in names]
models = [self.motors[name].model for name in names]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
ids_values, _ = self._sync_read(
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
)
ids_values = self._decode_sign(data_name, ids_values)
if normalize and data_name in self.normalized_data:
ids_values = self._normalize(data_name, ids_values)
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
def _sync_read(
self,
addr: int,
length: int,
motor_ids: list[int],
*,
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> tuple[dict[int, int], int]:
self._setup_sync_reader(motor_ids, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_reader.txRxPacket()
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
return values, comm
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
self.sync_reader.clearParam()
self.sync_reader.start_address = addr
self.sync_reader.data_length = length
for id_ in motor_ids:
self.sync_reader.addParam(id_)
# TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be.
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
# This could be at the cost of increase latency between the moment the data is produced by the motors and
# the moment it is used by a policy.
# def _async_read(self, motor_ids: list[int], address: int, length: int):
# if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...:
# self._setup_sync_reader(motor_ids, address, length)
# else:
# self.sync_reader.rxPacket()
# self.sync_reader.txPacket()
# for id_ in motor_ids:
# value = self.sync_reader.getData(id_, address, length)
def sync_write(
self,
data_name: str,
values: Value | dict[str, Value],
*,
normalize: bool = True,
num_retry: int = 0,
) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
)
ids_values = self._get_ids_values_dict(values)
models = [self._id_to_model(id_) for id_ in ids_values]
if self._has_different_ctrl_tables:
assert_same_address(self.model_ctrl_table, models, data_name)
model = next(iter(models))
addr, length = get_address(self.model_ctrl_table, model, data_name)
if normalize and data_name in self.normalized_data:
ids_values = self._unnormalize(data_name, ids_values)
ids_values = self._encode_sign(data_name, ids_values)
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
def _sync_write(
self,
addr: int,
length: int,
ids_values: dict[int, int],
num_retry: int = 0,
raise_on_error: bool = True,
err_msg: str = "",
) -> int:
self._setup_sync_writer(ids_values, addr, length)
for n_try in range(1 + num_retry):
comm = self.sync_writer.txPacket()
if self._is_comm_success(comm):
break
logger.debug(
f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): "
+ self.packet_handler.getTxRxResult(comm)
)
if not self._is_comm_success(comm) and raise_on_error:
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
return comm
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None:
self.sync_writer.clearParam()
self.sync_writer.start_address = addr
self.sync_writer.data_length = length
for id_, value in ids_values.items():
data = self._serialize_data(value, length)
self.sync_writer.addParam(id_, data)
def disconnect(self, disable_torque: bool = True) -> None:
if not self.is_connected:
raise DeviceNotConnectedError(
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
)
if disable_torque:
self.port_handler.clearPort()
self.port_handler.is_using = False
self.disable_torque(num_retry=5)
self.port_handler.closePort()
logger.debug(f"{self.__class__.__name__} disconnected.")

View File

@@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import draccus
import torch
@@ -44,7 +45,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
return "adam"
@abc.abstractmethod
def build(self) -> torch.optim.Optimizer:
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
raise NotImplementedError
@@ -94,7 +95,76 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
"""Configuration for multiple Adam optimizers with different parameter groups.
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
Args:
lr: Default learning rate (used if not specified for a group)
weight_decay: Default weight decay (used if not specified for a group)
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
grad_clip_norm: Gradient clipping norm
"""
lr: float = 1e-3
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
"""Build multiple Adam optimizers.
Args:
params_dict: Dictionary mapping parameter group names to lists of parameters
The keys should match the keys in optimizer_groups
Returns:
Dictionary mapping parameter group names to their optimizers
"""
optimizers = {}
for name, params in params_dict.items():
# Get group-specific hyperparameters or use defaults
group_config = self.optimizer_groups.get(name, {})
# Create optimizer with merged parameters (defaults + group-specific)
optimizer_kwargs = {
"lr": group_config.get("lr", self.lr),
"betas": group_config.get("betas", (0.9, 0.999)),
"eps": group_config.get("eps", 1e-5),
"weight_decay": group_config.get("weight_decay", self.weight_decay),
}
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
return optimizers
def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
"""Save optimizer state to disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir)
else:
# Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir)
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
"""Save a single optimizer's state to disk."""
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
@@ -102,11 +172,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
def load_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
"""Load optimizer state from disk.
Args:
optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to load the optimizer state from.
Returns:
The updated optimizer(s) with loaded state.
"""
if isinstance(optimizer, dict):
# Handle dictionary of optimizers
loaded_optimizers = {}
for name, opt in optimizer.items():
optimizer_dir = save_dir / name
if optimizer_dir.exists():
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
else:
loaded_optimizers[name] = opt
return loaded_optimizers
else:
# Handle single optimizer
return _load_single_optimizer_state(optimizer, save_dir)
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
"""Load a single optimizer's state from disk."""
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
# Handle case where 'state' key might not exist (for newly created optimizers)
if "state" in state:
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
else:
loaded_state_dict = {"state": {}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(

View File

@@ -49,7 +49,11 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
kwargs = {
**asdict(self),
"num_training_steps": num_training_steps,
"optimizer": optimizer,
}
return get_scheduler(**kwargs)
@@ -71,7 +75,10 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
progress = float(adjusted_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
return max(
0.0,
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
)
return LambdaLR(optimizer, lr_lambda, -1)

View File

@@ -241,7 +241,9 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
@@ -253,7 +255,10 @@ class ACTTemporalEnsembler:
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@@ -333,7 +338,11 @@ class ACT(nn.Module):
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
@@ -427,7 +436,11 @@ class ACT(nn.Module):
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -540,7 +553,10 @@ class ACTEncoder(nn.Module):
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -603,7 +619,10 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)

View File

@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -209,7 +209,10 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
@@ -238,8 +241,8 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
global_cond_feats = [batch[OBS_STATE]]
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
# Extract image features.
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
@@ -254,7 +257,10 @@ class DiffusionModel(nn.Module):
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
@@ -264,12 +270,15 @@ class DiffusionModel(nn.Module):
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
global_cond_feats.append(img_features)
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV_STATE])
global_cond_feats.append(batch[OBS_ENV])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
@@ -515,7 +524,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@@ -633,10 +644,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
]
)

View File

@@ -24,6 +24,7 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -59,6 +60,14 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "sac":
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "hilserl_classifier":
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
return Classifier
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -76,6 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "hilserl_classifier":
return ClassifierConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from typing import List
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
@dataclass
class ClassifierConfig(PreTrainedConfig):
"""Configuration for the Classifier model."""
name: str = "hilserl_classifier"
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
learning_rate: float = 1e-4
normalization_mode = None
# output_features: Dict[str, PolicyFeature] = field(
# default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
# )
@property
def observation_delta_indices(self) -> List | None:
return None
@property
def action_delta_indices(self) -> List | None:
return None
@property
def reward_delta_indices(self) -> List | None:
return None
def get_optimizer_preset(self) -> OptimizerConfig:
return AdamWConfig(
lr=self.learning_rate,
weight_decay=0.01,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
return None
def validate_features(self) -> None:
"""Validate feature configurations."""
# Classifier doesn't need specific feature validation
pass

View File

@@ -0,0 +1,237 @@
import logging
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from lerobot.common.constants import OBS_IMAGE
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(PreTrainedPolicy):
"""Image classifier built on top of a pre-trained encoder."""
name = "hilserl_classifier"
config_class = ClassifierConfig
def __init__(
self,
config: ClassifierConfig,
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
):
from transformers import AutoModel
super().__init__(config)
self.config = config
# Initialize normalization (standardized with the policy framework)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
# Set up encoder
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.feature_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoder(x)
# Get pooled output directly
features = outputs.pooler_output
if features.dim() > 2:
features = features.squeeze(-1).squeeze(-1)
return features
else: # Transformer models
outputs = self.encoder(x)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
"""Extract image tensors and label tensors from batch."""
# Find image keys in input features
image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]
# Extract the images and labels
images = [batch[key] for key in image_keys]
labels = batch["next.reward"]
return images, labels
def predict(self, xs: list) -> ClassifierOutput:
"""Forward pass of the classifier for inference."""
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
"""Standard forward pass for training compatible with train.py."""
# Normalize inputs if needed
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract images and labels
images, labels = self.extract_images_and_labels(batch)
# Get predictions
outputs = self.predict(images)
# Calculate loss
if self.config.num_classes == 2:
# Binary classification
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
# Multi-class classification
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
predictions = torch.argmax(outputs.logits, dim=1)
# Calculate accuracy for logging
correct = (predictions == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
# Return loss and metrics for logging
output_dict = {
"accuracy": accuracy,
"correct": correct,
"total": total,
}
return loss, output_dict
def predict_reward(self, batch, threshold=0.6):
"""Legacy method for compatibility."""
images, _ = self.extract_images_and_labels(batch)
if self.config.num_classes == 2:
probs = self.predict(images).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.predict(images).probabilities, dim=1)
# Methods required by PreTrainedPolicy abstract class
def get_optim_params(self) -> dict:
"""Return optimizer parameters for the policy."""
return {
"params": self.parameters(),
"lr": getattr(self.config, "learning_rate", 1e-4),
"weight_decay": getattr(self.config, "weight_decay", 0.01),
}
def reset(self):
"""Reset any stateful components (required by PreTrainedPolicy)."""
# Classifier doesn't have stateful components that need resetting
pass
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
"""Return action (class prediction) based on input observation."""
images, _ = self.extract_images_and_labels(batch)
with torch.no_grad():
outputs = self.predict(images)
if self.config.num_classes == 2:
# For binary classification return 0 or 1
return (outputs.probabilities > 0.5).float()
else:
# For multi-class return the predicted class
return torch.argmax(outputs.probabilities, dim=1)

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +17,7 @@
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("keyboard")
@dataclass
class KeyboardTeleopConfig(TeleoperatorConfig):
mock: bool = False
class HILSerlConfig:
pass

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
@TeleoperatorConfig.register_subclass("widowx")
@dataclass
class WidowXConfig(TeleoperatorConfig):
port: str # Port to connect to the arm
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass

View File

@@ -79,28 +79,46 @@ def create_stats_buffers(
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
if stats and key in stats:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
)
if isinstance(stats[key]["mean"], np.ndarray):
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
else:
type_ = type(stats[key]["mean"])
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
)
if isinstance(stats[key]["min"], np.ndarray):
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["min"], torch.Tensor):
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
else:
type_ = type(stats[key]["min"])
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
)
stats_buffers[key] = buffer
return stats_buffers
@@ -149,12 +167,13 @@ class Normalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
@@ -223,7 +242,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():

View File

@@ -61,7 +61,11 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
)
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
PRECISIONS = {
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float16": torch.float16,
}
def slice_paligemma_state_dict(state_dict, config):

View File

@@ -48,18 +48,32 @@ def flex_attention_forward(
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
key_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
batch_size,
value_states.shape[1],
num_key_value_heads * num_key_value_groups,
head_dim,
)
query_states = query_states.transpose(1, 2)

View File

@@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.constants import ACTION, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.paligemma_with_expert import (
@@ -69,7 +69,11 @@ from lerobot.common.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
time: torch.tensor,
dimension: int,
min_period: float,
max_period: float,
device="cpu",
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
@@ -271,7 +275,7 @@ class PI0Policy(PreTrainedPolicy):
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch = self.normalize_inputs(batch)
@@ -303,7 +307,7 @@ class PI0Policy(PreTrainedPolicy):
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
batch = self.normalize_inputs(batch)
@@ -380,7 +384,7 @@ class PI0Policy(PreTrainedPolicy):
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
device = batch[OBS_ROBOT].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
@@ -427,7 +431,7 @@ class PI0Policy(PreTrainedPolicy):
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
return state
def prepare_action(self, batch):
@@ -577,7 +581,11 @@ class PI0FlowMatching(nn.Module):
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
timestep,
self.config.proj_width,
min_period=4e-3,
max_period=4.0,
device=device,
)
time_emb = time_emb.type(dtype=dtype)
@@ -609,7 +617,15 @@ class PI0FlowMatching(nn.Module):
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
actions,
noise=None,
time=None,
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
@@ -655,7 +671,11 @@ class PI0FlowMatching(nn.Module):
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
actions_shape = (
bsize,
self.config.n_action_steps,
self.config.max_action_dim,
)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(

View File

@@ -293,12 +293,18 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
[past_key_values[layer_idx]["value_states"], value_states],
dim=1,
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
)
att_output = att_output.to(dtype=torch.bfloat16)
@@ -358,12 +364,24 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
self,
attention_mask,
batch_size,
head_dim,
query_states,
key_states,
value_states,
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
@@ -375,17 +393,31 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads,
num_key_value_groups,
head_dim,
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
batch_size,
sequence_length,
num_key_value_heads * num_key_value_groups,
head_dim,
)
# Attention here is upcasted to float32 to match the original eager implementation.

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import MultiAdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@dataclass
class ConcurrencyConfig:
actor: str = "threads"
learner: str = "threads"
@dataclass
class ActorLearnerConfig:
learner_host: str = "127.0.0.1"
learner_port: int = 50051
policy_parameters_push_frequency: int = 4
@dataclass
class CriticNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
final_activation: str | None = None
@dataclass
class ActorNetworkConfig:
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
activate_final: bool = True
@dataclass
class PolicyConfig:
use_tanh_squash: bool = True
log_std_min: float = 1e-5
log_std_max: float = 10.0
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
Args:
actor_network: Configuration for the actor network architecture.
critic_network: Configuration for the critic network architecture.
policy: Configuration for the policy parameters.
n_obs_steps: Number of observation steps to consider.
normalization_mapping: Mapping of feature types to normalization modes.
dataset_stats: Statistics for normalizing different types of inputs.
input_features: Dictionary of input features with their types and shapes.
output_features: Dictionary of output features with their types and shapes.
camera_number: Number of cameras used for visual observations.
device: Device to run the model on (e.g., "cuda", "cpu").
storage_device: Device to store the model on.
vision_encoder_name: Name of the vision encoder model.
freeze_vision_encoder: Whether to freeze the vision encoder during training.
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
shared_encoder: Whether to use a shared encoder for actor and critic.
num_discrete_actions: Number of discrete actions, eg for gripper actions.
image_embedding_pooling_dim: Dimension of the image embedding pooling.
concurrency: Configuration for concurrency settings.
actor_learner: Configuration for actor-learner architecture.
online_steps: Number of steps for online training.
online_env_seed: Seed for the online environment.
online_buffer_capacity: Capacity of the online replay buffer.
offline_buffer_capacity: Capacity of the offline replay buffer.
async_prefetch: Whether to use asynchronous prefetching for the buffers.
online_step_before_learning: Number of steps before learning starts.
policy_update_freq: Frequency of policy updates.
discount: Discount factor for the SAC algorithm.
temperature_init: Initial temperature value.
num_critics: Number of critics in the ensemble.
num_subsample_critics: Number of subsampled critics for training.
critic_lr: Learning rate for the critic network.
actor_lr: Learning rate for the actor network.
temperature_lr: Learning rate for the temperature parameter.
critic_target_update_weight: Weight for the critic target update.
utd_ratio: Update-to-data ratio for the UTD algorithm.
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
latent_dim: Dimension of the latent space.
target_entropy: Target entropy for the SAC algorithm.
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
"""
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ENV": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
"observation.image": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
}
)
# Architecture specifics
camera_number: int = 1
device: str = "cuda"
storage_device: str = "cpu"
# Set to "helper2424/resnet10" for hil serl
vision_encoder_name: str | None = None
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
num_discrete_actions: int | None = None
image_embedding_pooling_dim: int = 8
# Training parameter
online_steps: int = 1000000
online_env_seed: int = 10000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
async_prefetch: bool = False
online_step_before_learning: int = 100
policy_update_freq: int = 1
# SAC algorithm parameters
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2
num_subsample_critics: int | None = None
critic_lr: float = 3e-4
actor_lr: float = 3e-4
temperature_lr: float = 3e-4
critic_target_update_weight: float = 0.005
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
# Network configuration
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
def __post_init__(self):
super().__post_init__()
# Any validation specific to SAC configuration
def get_optimizer_preset(self) -> MultiAdamConfig:
return MultiAdamConfig(
weight_decay=0.0,
optimizer_groups={
"actor": {"lr": self.actor_lr},
"critic": {"lr": self.critic_lr},
"temperature": {"lr": self.temperature_lr},
},
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
has_image = any(key.startswith("observation.image") for key in self.input_features)
has_state = "observation.state" in self.input_features
if not (has_state or has_image):
raise ValueError(
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def image_features(self) -> list[str]:
return [key for key in self.input_features if "image" in key]
@property
def observation_delta_indices(self) -> list:
return None
@property
def action_delta_indices(self) -> list:
return None # SAC typically predicts one action at a time
@property
def reward_delta_indices(self) -> None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -35,11 +35,15 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)
class TDMPCPolicy(PreTrainedPolicy):
@@ -63,7 +67,11 @@ class TDMPCPolicy(PreTrainedPolicy):
config_class = TDMPCConfig
name = "tdmpc"
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
def __init__(
self,
config: TDMPCConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@@ -189,13 +197,20 @@ class TDMPCPolicy(PreTrainedPolicy):
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
self.config.horizon,
batch_size,
self.config.action_feature.shape[0],
device=device,
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -291,9 +306,10 @@ class TDMPCPolicy(PreTrainedPolicy):
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
* torch.min(
terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
dim=0,
)[0]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
@@ -329,7 +345,10 @@ class TDMPCPolicy(PreTrainedPolicy):
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"],
)
@@ -553,7 +572,10 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.action_feature.shape[0],
config.mlp_dim,
),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -702,11 +724,26 @@ class TDMPCObservationEncoder(nn.Module):
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
)
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
@@ -749,13 +786,14 @@ class TDMPCObservationEncoder(nn.Module):
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
self.image_enc_layers,
obs_dict[next(iter(self.config.image_features))],
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
return torch.stack(feat, dim=0).mean(0)
@@ -796,7 +834,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):

View File

@@ -193,7 +193,12 @@ class VQBeTConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list:
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
return list(
range(
1 - self.n_obs_steps,
self.n_action_pred_token + self.action_chunk_size - 1,
)
)
@property
def reward_delta_indices(self) -> None:

View File

@@ -29,7 +29,11 @@ from torch import Tensor, nn
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
@@ -324,7 +328,8 @@ class VQBeTModel(nn.Module):
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
config.robot_state_feature.shape[0],
hidden_channels=[self.config.gpt_input_dim],
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -354,7 +359,11 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@@ -391,7 +400,11 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
)
else:
features = features[:, historical_act_pred_index]
@@ -514,7 +527,13 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1,
)
)
@@ -532,7 +551,9 @@ class VQBeTHead(nn.Module):
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
@@ -730,7 +751,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:

View File

@@ -377,7 +377,10 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
)
for _ in range(num_quantizers)
]

View File

@@ -12,24 +12,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass
from ..configs import CameraConfig
import draccus
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class RealSenseCameraConfig(CameraConfig):
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@CameraConfig.register_subclass("opencv")
@dataclass
class OpenCVCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
RealSenseCameraConfig(128422271347, 30, 640, 480)
RealSenseCameraConfig(128422271347, 60, 640, 480)
RealSenseCameraConfig(128422271347, 90, 640, 480)
RealSenseCameraConfig(128422271347, 30, 1280, 720)
RealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
RealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
OpenCVCameraConfig(0, 30, 640, 480)
OpenCVCameraConfig(0, 60, 640, 480)
OpenCVCameraConfig(0, 90, 640, 480)
OpenCVCameraConfig(0, 30, 1280, 720)
```
"""
camera_index: int
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
channels: int | None = None
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
@CameraConfig.register_subclass("intelrealsense")
@dataclass
class IntelRealSenseCameraConfig(CameraConfig):
"""
Example of tested options for Intel Real Sense D405:
```python
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
```
"""

View File

@@ -31,15 +31,14 @@ from threading import Thread
import numpy as np
from PIL import Image
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.robot_utils import (
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
busy_wait,
)
from lerobot.common.utils.utils import capture_timestamp_utc
from ..camera import Camera
from .configuration_realsense import RealSenseCameraConfig
SERIAL_NUMBER_INDEX = 1
@@ -109,11 +108,13 @@ def save_images_from_cameras(
cameras = []
for cam_sn in serial_numbers:
print(f"{cam_sn=}")
config = RealSenseCameraConfig(serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock)
camera = RealSenseCamera(config)
config = IntelRealSenseCameraConfig(
serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
)
camera = IntelRealSenseCamera(config)
camera.connect()
print(
f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
)
cameras.append(camera)
@@ -165,11 +166,11 @@ def save_images_from_cameras(
camera.disconnect()
class RealSenseCamera(Camera):
class IntelRealSenseCamera:
"""
The RealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
- is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux,
- can also be instantiated with the camera's name — if it's unique using RealSenseCamera.init_from_name(),
- can also be instantiated with the camera's name — if it's unique using IntelRealSenseCamera.init_from_name(),
- depth map can be returned.
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
@@ -177,15 +178,15 @@ class RealSenseCamera(Camera):
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
```
When an RealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
of the given camera will be used.
Example of instantiating with a serial number:
```python
from lerobot.common.robot_devices.cameras.configs import RealSenseCameraConfig
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
config = RealSenseCameraConfig(serial_number=128422271347)
camera = RealSenseCamera(config)
config = IntelRealSenseCameraConfig(serial_number=128422271347)
camera = IntelRealSenseCamera(config)
camera.connect()
color_image = camera.read()
# when done using the camera, consider disconnecting
@@ -194,21 +195,21 @@ class RealSenseCamera(Camera):
Example of instantiating with a name if it's unique:
```
config = RealSenseCameraConfig(name="Intel RealSense D405")
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
```
Example of changing default fps, width, height and color_mode:
```python
config = RealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
```
Example of returning depth:
```python
config = RealSenseCameraConfig(serial_number=128422271347, use_depth=True)
camera = RealSenseCamera(config)
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
camera = IntelRealSenseCamera(config)
camera.connect()
color_image, depth_map = camera.read()
```
@@ -216,7 +217,7 @@ class RealSenseCamera(Camera):
def __init__(
self,
config: RealSenseCameraConfig,
config: IntelRealSenseCameraConfig,
):
self.config = config
if config.name is not None:
@@ -281,7 +282,9 @@ class RealSenseCamera(Camera):
def connect(self):
if self.is_connected:
raise DeviceAlreadyConnectedError(f"RealSenseCamera({self.serial_number}) is already connected.")
raise RobotDeviceAlreadyConnectedError(
f"IntelRealSenseCamera({self.serial_number}) is already connected."
)
if self.mock:
import tests.cameras.mock_pyrealsense2 as rs
@@ -294,7 +297,11 @@ class RealSenseCamera(Camera):
if self.fps and self.capture_width and self.capture_height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
rs.stream.color,
self.capture_width,
self.capture_height,
rs.format.rgb8,
self.fps,
)
else:
config.enable_stream(rs.stream.color)
@@ -302,7 +309,11 @@ class RealSenseCamera(Camera):
if self.use_depth:
if self.fps and self.capture_width and self.capture_height:
config.enable_stream(
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
rs.stream.depth,
self.capture_width,
self.capture_height,
rs.format.z16,
self.fps,
)
else:
config.enable_stream(rs.stream.depth)
@@ -327,7 +338,7 @@ class RealSenseCamera(Camera):
"To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`."
)
raise OSError(f"Can't access RealSenseCamera({self.serial_number}).")
raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).")
color_stream = profile.get_stream(rs.stream.color)
color_profile = color_stream.as_video_stream_profile()
@@ -339,15 +350,15 @@ class RealSenseCamera(Camera):
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
# Using `OSError` since it's a broad that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for RealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
)
if self.capture_width is not None and self.capture_width != actual_width:
raise OSError(
f"Can't set {self.capture_width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}."
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
)
if self.capture_height is not None and self.capture_height != actual_height:
raise OSError(
f"Can't set {self.capture_height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}."
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
)
self.fps = round(actual_fps)
@@ -367,8 +378,8 @@ class RealSenseCamera(Camera):
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
raise RobotDeviceNotConnectedError(
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
)
if self.mock:
@@ -383,7 +394,7 @@ class RealSenseCamera(Camera):
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(f"Can't capture color image from RealSenseCamera({self.serial_number}).")
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
color_image = np.asanyarray(color_frame.get_data())
@@ -415,7 +426,7 @@ class RealSenseCamera(Camera):
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(f"Can't capture depth image from RealSenseCamera({self.serial_number}).")
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
depth_map = np.asanyarray(depth_frame.get_data())
@@ -442,8 +453,8 @@ class RealSenseCamera(Camera):
def async_read(self):
"""Access the latest color image"""
if not self.is_connected:
raise DeviceNotConnectedError(
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
raise RobotDeviceNotConnectedError(
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
)
if self.thread is None:
@@ -469,8 +480,8 @@ class RealSenseCamera(Camera):
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
raise RobotDeviceNotConnectedError(
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
)
if self.thread is not None and self.thread.is_alive():
@@ -492,14 +503,14 @@ class RealSenseCamera(Camera):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Save a few frames using `RealSenseCamera` for all cameras connected to the computer, or a selected subset."
description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
)
parser.add_argument(
"--serial-numbers",
type=int,
nargs="*",
default=None,
help="List of serial numbers used to instantiate the `RealSenseCamera`. If not provided, find and use all available camera indices.",
help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.",
)
parser.add_argument(
"--fps",
@@ -509,13 +520,13 @@ if __name__ == "__main__":
)
parser.add_argument(
"--width",
type=str,
type=int,
default=640,
help="Set the width for all cameras. If not provided, use the default width of each camera.",
)
parser.add_argument(
"--height",
type=str,
type=int,
default=480,
help="Set the height for all cameras. If not provided, use the default height of each camera.",
)

View File

@@ -24,20 +24,19 @@ import shutil
import threading
import time
from pathlib import Path
from threading import Thread
import cv2
import numpy as np
from PIL import Image
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.utils.robot_utils import (
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
busy_wait,
)
from lerobot.common.utils.utils import capture_timestamp_utc
from ..camera import Camera
from .configuration_opencv import OpenCVCameraConfig
# The maximum opencv device index depends on your operating system. For instance,
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
@@ -46,12 +45,12 @@ from .configuration_opencv import OpenCVCameraConfig
MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX) -> list[dict]:
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports)
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
cameras.append(
{
@@ -65,7 +64,7 @@ def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX
f"scanning all indices from 0 to {MAX_OPENCV_INDEX}"
)
possible_indices = range(max_index_search_range)
indices = _find_cameras(possible_indices)
indices = _find_cameras(possible_indices, mock=mock)
for index in indices:
cameras.append(
{
@@ -77,7 +76,14 @@ def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX
return cameras
def _find_cameras(possible_camera_ids: list[int | str], raise_when_empty=False) -> list[int | str]:
def _find_cameras(
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
) -> list[int | str]:
if mock:
import tests.cameras.mock_cv2 as cv2
else:
import cv2
camera_ids = []
for camera_idx in possible_camera_ids:
camera = cv2.VideoCapture(camera_idx)
@@ -121,19 +127,20 @@ def save_images_from_cameras(
width=None,
height=None,
record_time_s=2,
mock=False,
):
"""
Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
associated to a given camera index.
"""
if camera_ids is None or len(camera_ids) == 0:
camera_infos = find_cameras()
camera_infos = find_cameras(mock=mock)
camera_ids = [cam["index"] for cam in camera_infos]
print("Connecting cameras")
cameras = []
for cam_idx in camera_ids:
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height)
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
camera = OpenCVCamera(config)
camera.connect()
print(
@@ -183,7 +190,7 @@ def save_images_from_cameras(
print(f"Images have been saved to {images_dir}")
class OpenCVCamera(Camera):
class OpenCVCamera:
"""
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
@@ -252,6 +259,7 @@ class OpenCVCamera(Camera):
self.fps = config.fps
self.channels = config.channels
self.color_mode = config.color_mode
self.mock = config.mock
self.camera = None
self.is_connected = False
@@ -260,6 +268,11 @@ class OpenCVCamera(Camera):
self.color_image = None
self.logs = {}
if self.mock:
import tests.cameras.mock_cv2 as cv2
else:
import cv2
self.rotation = None
if config.rotation == -90:
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
@@ -270,11 +283,16 @@ class OpenCVCamera(Camera):
def connect(self):
if self.is_connected:
raise DeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
# when other threads are used to save the images.
cv2.setNumThreads(1)
if self.mock:
import tests.cameras.mock_cv2 as cv2
else:
import cv2
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
# when other threads are used to save the images.
cv2.setNumThreads(1)
backend = (
cv2.CAP_V4L2
@@ -287,11 +305,17 @@ class OpenCVCamera(Camera):
)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx, backend)
is_camera_open = tmp_camera.isOpened()
# Release camera to make it accessible for `find_camera_indices`
tmp_camera.release()
del tmp_camera
self.camera = cv2.VideoCapture(camera_idx, backend)
if not self.camera.isOpened():
self.camera.release() # Release the failed attempt
# If the camera doesn't work, display the camera indices corresponding to
# valid cameras.
if not is_camera_open:
# Verify that the provided `camera_index` is valid before printing the traceback
cameras_info = find_cameras()
available_cam_ids = [cam["index"] for cam in cameras_info]
@@ -303,6 +327,11 @@ class OpenCVCamera(Camera):
raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
# Secondly, create the camera that will be used downstream.
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
# needs to be re-created.
self.camera = cv2.VideoCapture(camera_idx, backend)
if self.fps is not None:
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
if self.capture_width is not None:
@@ -346,7 +375,7 @@ class OpenCVCamera(Camera):
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
@@ -368,6 +397,11 @@ class OpenCVCamera(Camera):
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
# so we convert the image color from BGR to RGB.
if requested_color_mode == "rgb":
if self.mock:
import tests.cameras.mock_cv2 as cv2
else:
import cv2
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
h, w, _ = color_image.shape
@@ -398,13 +432,13 @@ class OpenCVCamera(Camera):
def async_read(self):
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
if self.thread is None:
self.stop_event = threading.Event()
self.thread = threading.Thread(target=self.read_loop, args=())
self.thread = Thread(target=self.read_loop, args=())
self.thread.daemon = True
self.thread.start()
@@ -420,7 +454,7 @@ class OpenCVCamera(Camera):
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
)
@@ -458,13 +492,13 @@ if __name__ == "__main__":
)
parser.add_argument(
"--width",
type=str,
type=int,
default=None,
help="Set the width for all cameras. If not provided, use the default width of each camera.",
)
parser.add_argument(
"--height",
type=str,
type=int,
default=None,
help="Set the height for all cameras. If not provided, use the default height of each camera.",
)

View File

@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
import numpy as np
from lerobot.common.robot_devices.cameras.configs import (
CameraConfig,
IntelRealSenseCameraConfig,
OpenCVCameraConfig,
)
# Defines a camera type
class Camera(Protocol):
def connect(self): ...
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
def async_read(self) -> np.ndarray: ...
def disconnect(self): ...
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]:
cameras = {}
for key, cfg in camera_configs.items():
if cfg.type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
cameras[key] = OpenCVCamera(cfg)
elif cfg.type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)
cameras[key] = IntelRealSenseCamera(cfg)
else:
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
return cameras
def make_camera(camera_type, **kwargs) -> Camera:
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
config = OpenCVCameraConfig(**kwargs)
return OpenCVCamera(config)
elif camera_type == "intelrealsense":
from lerobot.common.robot_devices.cameras.intelrealsense import (
IntelRealSenseCamera,
)
config = IntelRealSenseCameraConfig(**kwargs)
return IntelRealSenseCamera(config)
else:
raise ValueError(f"The camera type '{camera_type}' is not valid.")

View File

@@ -17,7 +17,7 @@ from pathlib import Path
import draccus
from lerobot.common.robots import RobotConfig
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
@@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# Reset follower arms to an initial configuration.
reset_follower_arms: bool = True
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.

View File

@@ -25,6 +25,7 @@ from copy import copy
from functools import cache
import rerun as rr
import numpy as np
import torch
from deepdiff import DeepDiff
from termcolor import colored
@@ -33,8 +34,8 @@ from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots.utils import Robot
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -128,14 +129,22 @@ def predict_action(observation, policy, device, use_amp):
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
def init_keyboard_listener(assign_rewards=False):
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
Args:
assign_rewards (bool): If True, allows annotating the collected trajectory
with a binary reward at the end of the episode to indicate success.
"""
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if assign_rewards:
events["next.reward"] = 0
if is_headless():
logging.warning(
@@ -160,6 +169,13 @@ def init_keyboard_listener():
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif assign_rewards and key == keyboard.Key.space:
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
print(
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
events["next.reward"],
)
except Exception as e:
print(f"Error handling key press: {e}")
@@ -246,6 +262,8 @@ def control_loop(
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
else:
@@ -253,7 +271,10 @@ def control_loop(
if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
observation,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
@@ -301,7 +322,17 @@ def reset_environment(robot, events, reset_time_s, fps):
)
def stop_recording(robot, listener, display_data):
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an aribtrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
if not is_headless() and listener is not None:
@@ -327,12 +358,20 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
("features", dataset.features, features_from_robot),
]
mismatches = []

View File

@@ -0,0 +1,881 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import math
import time
import traceback
from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0
BAUDRATE = 1_000_000
TIMEOUT_MS = 1000
MAX_ID_RANGE = 252
# The following bounds define the lower and upper joints range (after calibration).
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
# which corresponds to a half rotation on the left and half rotation on the right.
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
# an error is raised.
LOWER_BOUND_DEGREE = -270
UPPER_BOUND_DEGREE = 270
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
# [-10, 110] until an error is raised.
LOWER_BOUND_LINEAR = -10
UPPER_BOUND_LINEAR = 110
HALF_TURN_DEGREE = 180
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150
# data_name: (address, size_byte)
X_SERIES_CONTROL_TABLE = {
"Model_Number": (0, 2),
"Model_Information": (2, 4),
"Firmware_Version": (6, 1),
"ID": (7, 1),
"Baud_Rate": (8, 1),
"Return_Delay_Time": (9, 1),
"Drive_Mode": (10, 1),
"Operating_Mode": (11, 1),
"Secondary_ID": (12, 1),
"Protocol_Type": (13, 1),
"Homing_Offset": (20, 4),
"Moving_Threshold": (24, 4),
"Temperature_Limit": (31, 1),
"Max_Voltage_Limit": (32, 2),
"Min_Voltage_Limit": (34, 2),
"PWM_Limit": (36, 2),
"Current_Limit": (38, 2),
"Acceleration_Limit": (40, 4),
"Velocity_Limit": (44, 4),
"Max_Position_Limit": (48, 4),
"Min_Position_Limit": (52, 4),
"Shutdown": (63, 1),
"Torque_Enable": (64, 1),
"LED": (65, 1),
"Status_Return_Level": (68, 1),
"Registered_Instruction": (69, 1),
"Hardware_Error_Status": (70, 1),
"Velocity_I_Gain": (76, 2),
"Velocity_P_Gain": (78, 2),
"Position_D_Gain": (80, 2),
"Position_I_Gain": (82, 2),
"Position_P_Gain": (84, 2),
"Feedforward_2nd_Gain": (88, 2),
"Feedforward_1st_Gain": (90, 2),
"Bus_Watchdog": (98, 1),
"Goal_PWM": (100, 2),
"Goal_Current": (102, 2),
"Goal_Velocity": (104, 4),
"Profile_Acceleration": (108, 4),
"Profile_Velocity": (112, 4),
"Goal_Position": (116, 4),
"Realtime_Tick": (120, 2),
"Moving": (122, 1),
"Moving_Status": (123, 1),
"Present_PWM": (124, 2),
"Present_Current": (126, 2),
"Present_Velocity": (128, 4),
"Present_Position": (132, 4),
"Velocity_Trajectory": (136, 4),
"Position_Trajectory": (140, 4),
"Present_Input_Voltage": (144, 2),
"Present_Temperature": (146, 1),
}
X_SERIES_BAUDRATE_TABLE = {
0: 9_600,
1: 57_600,
2: 115_200,
3: 1_000_000,
4: 2_000_000,
5: 3_000_000,
6: 4_000_000,
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
MODEL_CONTROL_TABLE = {
"x_series": X_SERIES_CONTROL_TABLE,
"xl330-m077": X_SERIES_CONTROL_TABLE,
"xl330-m288": X_SERIES_CONTROL_TABLE,
"xl430-w250": X_SERIES_CONTROL_TABLE,
"xm430-w350": X_SERIES_CONTROL_TABLE,
"xm540-w270": X_SERIES_CONTROL_TABLE,
"xc430-w150": X_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
"x_series": 4096,
"xl330-m077": 4096,
"xl330-m288": 4096,
"xl430-w250": 4096,
"xm430-w350": 4096,
"xm540-w270": 4096,
"xc430-w150": 4096,
}
MODEL_BAUDRATE_TABLE = {
"x_series": X_SERIES_BAUDRATE_TABLE,
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
}
NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
"""
resolutions = [MODEL_RESOLUTION[model] for model in models]
steps = degrees / 180 * np.array(resolutions) / 2
steps = steps.astype(int)
return steps
def convert_to_bytes(value, bytes, mock=False):
if mock:
return value
import dynamixel_sdk as dxl
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
]
elif bytes == 2:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
]
elif bytes == 4:
data = [
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
)
return data
def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
def get_result_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
rslt_name = f"{fn_name}_{group_key}"
return rslt_name
def get_queue_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
queue_name = f"{fn_name}_{group_key}"
return queue_name
def get_log_name(var_name, fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name
def assert_same_address(model_ctrl_table, motor_models, data_name):
all_addr = []
all_bytes = []
for model in motor_models:
addr, bytes = model_ctrl_table[model][data_name]
all_addr.append(addr)
all_bytes.append(bytes)
if len(set(all_addr)) != 1:
raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
)
if len(set(all_bytes)) != 1:
raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
)
class TorqueMode(enum.Enum):
ENABLED = 1
DISABLED = 0
class DriveMode(enum.Enum):
NON_INVERTED = 0
INVERTED = 1
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1
class JointOutOfRangeError(Exception):
def __init__(self, message="Joint is out of range"):
self.message = message
super().__init__(self.message)
class DynamixelMotorsBus:
"""
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
>>> Finding all available ports for the MotorBus.
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
```
Example of usage for 1 motor connected to the bus:
```python
motor_name = "gripper"
motor_index = 6
motor_model = "xl330-m288"
config = DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)},
)
motors_bus = DynamixelMotorsBus(config)
motors_bus.connect()
position = motors_bus.read("Present_Position")
# move from a few motor steps as an example
few_steps = 30
motors_bus.write("Goal_Position", position + few_steps)
# when done, consider disconnecting
motors_bus.disconnect()
```
"""
def __init__(
self,
config: DynamixelMotorsBusConfig,
):
self.port = config.port
self.motors = config.motors
self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
self.port_handler = None
self.packet_handler = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
)
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
except Exception:
traceback.print_exc()
print(
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
)
raise
# Allow to read and write
self.is_connected = True
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
def reconnect(self):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
self.is_connected = True
def are_motors_configured(self):
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
# a ConnectionError will be raised anyway.
try:
return (self.motor_indices == self.read("ID")).all()
except ConnectionError as e:
print(e)
return False
def find_motor_indices(self, possible_ids=None, num_retry=2):
if possible_ids is None:
possible_ids = range(MAX_ID_RANGE)
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
)
indices.append(idx)
return indices
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
raise OSError("Failed to write bus baud rate.")
@property
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
"""
try:
values = self.apply_calibration(values, motor_names)
except JointOutOfRangeError as e:
print(e)
self.autocorrect_calibration(values, motor_names)
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
in the centered nominal degree range ]-180, 180[.
"""
if motor_names is None:
motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
# Update direction of rotation of the motor to match between leader and follower.
# In fact, the motor of the leader for a given joint can be assembled in an
# opposite direction in term of rotation than the motor of the follower on the same joint.
if drive_mode:
values[i] *= -1
# Convert from range [-2**31, 2**31] to
# nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048])
values[i] += homing_offset
# Convert from range [-resolution//2, resolution//2] to
# universal float32 centered degree range [-180, 180]
# (e.g. 2048 / (4096 // 2) * 180 = 180)
values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
raise JointOutOfRangeError(
f"Wrong motor position range detected for {name}. "
f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
f"but present value is {values[i]} degree. "
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
)
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Rescale the present position to a nominal range [0, 100] %,
# useful for joints with linear motions like Aloha gripper
values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100
if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
raise JointOutOfRangeError(
f"Wrong motor position range detected for {name}. "
f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
f"but present value is {values[i]} %. "
"This might be due to a cable connection issue creating an artificial jump in motor values. "
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
)
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
Known issues:
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
#2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha).
#3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration
or by human error during manual calibration.
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
"""
if motor_names is None:
motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
# Update direction of rotation of the motor to match between leader and follower.
# In fact, the motor of the leader for a given joint can be assembled in an
# opposite direction in term of rotation than the motor of the follower on the same joint.
if drive_mode:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
low_factor = (start_pos - values[i]) / resolution
upp_factor = (end_pos - values[i]) / resolution
if not in_range:
# Get first integer between the two bounds
if low_factor < upp_factor:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
f"from '{out_of_range_str}' to '{in_range_str}'."
)
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
# Convert from nominal 0-centered degree range [-180, 180] to
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
# Remove drive mode, which is the rotation direction of the motor, to come back to
# actual motor rotation direction which can be arbitrary.
if drive_mode:
values[i] *= -1
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Convert from nominal lnear range of [0, 100] % to
# actual motor range of values which can be arbitrary.
values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
return_list = True
if not isinstance(motor_ids, list):
return_list = False
motor_ids = [motor_ids]
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
group.addParam(idx)
for _ in range(num_retry):
comm = group.txRxPacket()
if comm == dxl.COMM_SUCCESS:
break
if comm != dxl.COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = group.getData(idx, addr, bytes)
values.append(value)
if return_list:
return values
else:
return values[0]
def read(self, data_name, motor_names: str | list[str] | None = None):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers:
# create new group reader
self.group_readers[group_key] = dxl.GroupSyncRead(
self.port_handler, self.packet_handler, addr, bytes
)
for idx in motor_ids:
self.group_readers[group_key].addParam(idx)
for _ in range(NUM_READ_RETRY):
comm = self.group_readers[group_key].txRxPacket()
if comm == dxl.COMM_SUCCESS:
break
if comm != dxl.COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes)
values.append(value)
values = np.array(values)
# Convert to signed int to use range [-2048, 2048] for our motor positions.
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = values.astype(np.int32)
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
if not isinstance(motor_ids, list):
motor_ids = [motor_ids]
if not isinstance(values, list):
values = [values]
assert_same_address(self.model_ctrl_table, motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock)
group.addParam(idx, data)
for _ in range(num_retry):
comm = group.txPacket()
if comm == dxl.COMM_SUCCESS:
break
if comm != dxl.COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_dynamixel_sdk as dxl
else:
import dynamixel_sdk as dxl
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
if isinstance(values, (int, float, np.integer)):
values = [int(values)] * len(motor_names)
values = np.array(values)
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.revert_calibration(values, motor_names)
values = values.tolist()
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers
if init_group:
self.group_writers[group_key] = dxl.GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes
)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock)
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:
self.group_writers[group_key].changeParam(idx, data)
comm = self.group_writers[group_key].txPacket()
if comm != dxl.COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?
# log the utc time when the write has been completed
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
)
if self.port_handler is not None:
self.port_handler.closePort()
self.port_handler = None
self.packet_handler = None
self.group_readers = {}
self.group_writers = {}
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@@ -0,0 +1,906 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import math
import time
import traceback
from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 0
BAUDRATE = 1_000_000
TIMEOUT_MS = 1000
MAX_ID_RANGE = 252
# The following bounds define the lower and upper joints range (after calibration).
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
# which corresponds to a half rotation on the left and half rotation on the right.
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
# an error is raised.
LOWER_BOUND_DEGREE = -270
UPPER_BOUND_DEGREE = 270
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
# [-10, 110] until an error is raised.
LOWER_BOUND_LINEAR = -10
UPPER_BOUND_LINEAR = 110
HALF_TURN_DEGREE = 180
# See this link for STS3215 Memory Table:
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
# data_name: (address, size_byte)
SCS_SERIES_CONTROL_TABLE = {
"Model": (3, 2),
"ID": (5, 1),
"Baud_Rate": (6, 1),
"Return_Delay": (7, 1),
"Response_Status_Level": (8, 1),
"Min_Angle_Limit": (9, 2),
"Max_Angle_Limit": (11, 2),
"Max_Temperature_Limit": (13, 1),
"Max_Voltage_Limit": (14, 1),
"Min_Voltage_Limit": (15, 1),
"Max_Torque_Limit": (16, 2),
"Phase": (18, 1),
"Unloading_Condition": (19, 1),
"LED_Alarm_Condition": (20, 1),
"P_Coefficient": (21, 1),
"D_Coefficient": (22, 1),
"I_Coefficient": (23, 1),
"Minimum_Startup_Force": (24, 2),
"CW_Dead_Zone": (26, 1),
"CCW_Dead_Zone": (27, 1),
"Protection_Current": (28, 2),
"Angular_Resolution": (30, 1),
"Offset": (31, 2),
"Mode": (33, 1),
"Protective_Torque": (34, 1),
"Protection_Time": (35, 1),
"Overload_Torque": (36, 1),
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
"Over_Current_Protection_Time": (38, 1),
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
"Torque_Enable": (40, 1),
"Acceleration": (41, 1),
"Goal_Position": (42, 2),
"Goal_Time": (44, 2),
"Goal_Speed": (46, 2),
"Torque_Limit": (48, 2),
"Lock": (55, 1),
"Present_Position": (56, 2),
"Present_Speed": (58, 2),
"Present_Load": (60, 2),
"Present_Voltage": (62, 1),
"Present_Temperature": (63, 1),
"Status": (65, 1),
"Moving": (66, 1),
"Present_Current": (69, 2),
# Not in the Memory Table
"Maximum_Acceleration": (85, 2),
}
SCS_SERIES_BAUDRATE_TABLE = {
0: 1_000_000,
1: 500_000,
2: 250_000,
3: 128_000,
4: 115_200,
5: 57_600,
6: 38_400,
7: 19_200,
}
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
MODEL_CONTROL_TABLE = {
"scs_series": SCS_SERIES_CONTROL_TABLE,
"sts3215": SCS_SERIES_CONTROL_TABLE,
}
MODEL_RESOLUTION = {
"scs_series": 4096,
"sts3215": 4096,
}
MODEL_BAUDRATE_TABLE = {
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
"sts3215": SCS_SERIES_BAUDRATE_TABLE,
}
# High number of retries is needed for feetech compared to dynamixel motors.
NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
"""
resolutions = [MODEL_RESOLUTION[model] for model in models]
steps = degrees / 180 * np.array(resolutions) / 2
steps = steps.astype(int)
return steps
def convert_to_bytes(value, bytes, mock=False):
if mock:
return value
import scservo_sdk as scs
# Note: No need to convert back into unsigned int, since this byte preprocessing
# already handles it for us.
if bytes == 1:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
]
elif bytes == 2:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
]
elif bytes == 4:
data = [
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
]
else:
raise NotImplementedError(
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
f"{bytes} is provided instead."
)
return data
def get_group_sync_key(data_name, motor_names):
group_key = f"{data_name}_" + "_".join(motor_names)
return group_key
def get_result_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
rslt_name = f"{fn_name}_{group_key}"
return rslt_name
def get_queue_name(fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
queue_name = f"{fn_name}_{group_key}"
return queue_name
def get_log_name(var_name, fn_name, data_name, motor_names):
group_key = get_group_sync_key(data_name, motor_names)
log_name = f"{var_name}_{fn_name}_{group_key}"
return log_name
def assert_same_address(model_ctrl_table, motor_models, data_name):
all_addr = []
all_bytes = []
for model in motor_models:
addr, bytes = model_ctrl_table[model][data_name]
all_addr.append(addr)
all_bytes.append(bytes)
if len(set(all_addr)) != 1:
raise NotImplementedError(
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
)
if len(set(all_bytes)) != 1:
raise NotImplementedError(
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
)
class TorqueMode(enum.Enum):
ENABLED = 1
DISABLED = 0
class DriveMode(enum.Enum):
NON_INVERTED = 0
INVERTED = 1
class CalibrationMode(enum.Enum):
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
DEGREE = 0
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
LINEAR = 1
class JointOutOfRangeError(Exception):
def __init__(self, message="Joint is out of range"):
self.message = message
super().__init__(self.message)
class FeetechMotorsBus:
"""
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on
the python feetech sdk to communicate with the motors. For more info, see the [feetech SDK Documentation](https://emanual.robotis.com/docs/en/software/feetech/feetech_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
To find the port, you can run our utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
>>> Finding all available ports for the MotorsBus.
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
>>> Remove the usb cable from your FeetechMotorsBus and press Enter when done.
>>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751.
>>> Reconnect the usb cable.
```
Example of usage for 1 motor connected to the bus:
```python
motor_name = "gripper"
motor_index = 6
motor_model = "sts3215"
config = FeetechMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={motor_name: (motor_index, motor_model)},
)
motors_bus = FeetechMotorsBus(config)
motors_bus.connect()
position = motors_bus.read("Present_Position")
# move from a few motor steps as an example
few_steps = 30
motors_bus.write("Goal_Position", position + few_steps)
# when done, consider disconnecting
motors_bus.disconnect()
```
"""
def __init__(
self,
config: FeetechMotorsBusConfig,
):
self.port = config.port
self.motors = config.motors
self.mock = config.mock
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
self.model_resolution = deepcopy(MODEL_RESOLUTION)
self.port_handler = None
self.packet_handler = None
self.calibration = None
self.is_connected = False
self.group_readers = {}
self.group_writers = {}
self.logs = {}
self.track_positions = {}
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
)
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
self.port_handler = scs.PortHandler(self.port)
self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
try:
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
except Exception:
traceback.print_exc()
print(
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
)
raise
# Allow to read and write
self.is_connected = True
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
def reconnect(self):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
self.port_handler = scs.PortHandler(self.port)
self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
if not self.port_handler.openPort():
raise OSError(f"Failed to open port '{self.port}'.")
self.is_connected = True
def are_motors_configured(self):
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
# a ConnectionError will be raised anyway.
try:
return (self.motor_indices == self.read("ID")).all()
except ConnectionError as e:
print(e)
return False
def find_motor_indices(self, possible_ids=None, num_retry=2):
if possible_ids is None:
possible_ids = range(MAX_ID_RANGE)
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
except ConnectionError:
continue
if idx != present_idx:
# sanity check
raise OSError(
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
)
indices.append(idx)
return indices
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
raise OSError("Failed to write bus baud rate.")
@property
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
"""
try:
values = self.apply_calibration(values, motor_names)
except JointOutOfRangeError as e:
print(e)
self.autocorrect_calibration(values, motor_names)
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
in the centered nominal degree range ]-180, 180[.
"""
if motor_names is None:
motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
# Update direction of rotation of the motor to match between leader and follower.
# In fact, the motor of the leader for a given joint can be assembled in an
# opposite direction in term of rotation than the motor of the follower on the same joint.
if drive_mode:
values[i] *= -1
# Convert from range [-2**31, 2**31[ to
# nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
values[i] += homing_offset
# Convert from range ]-resolution, resolution[ to
# universal float32 centered degree range ]-180, 180[
values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
raise JointOutOfRangeError(
f"Wrong motor position range detected for {name}. "
f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
f"but present value is {values[i]} degree. "
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
)
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Rescale the present position to a nominal range [0, 100] %,
# useful for joints with linear motions like Aloha gripper
values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100
if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
raise JointOutOfRangeError(
f"Wrong motor position range detected for {name}. "
f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
f"but present value is {values[i]} %. "
"This might be due to a cable connection issue creating an artificial jump in motor values. "
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
)
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
Known issues:
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
#2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha).
#3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration
or by human error during manual calibration.
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
"""
if motor_names is None:
motor_names = self.motor_names
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
values = values.astype(np.float32)
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
if drive_mode:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
low_factor = (start_pos - values[i]) / resolution
upp_factor = (end_pos - values[i]) / resolution
if not in_range:
# Get first integer between the two bounds
if low_factor < upp_factor:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
f"from '{out_of_range_str}' to '{in_range_str}'."
)
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
calib_idx = self.calibration["motor_names"].index(name)
calib_mode = self.calibration["calib_mode"][calib_idx]
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
drive_mode = self.calibration["drive_mode"][calib_idx]
homing_offset = self.calibration["homing_offset"][calib_idx]
_, model = self.motors[name]
resolution = self.model_resolution[model]
# Convert from nominal 0-centered degree range [-180, 180] to
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
# Subtract the homing offsets to come back to actual motor range of values
# which can be arbitrary.
values[i] -= homing_offset
# Remove drive mode, which is the rotation direction of the motor, to come back to
# actual motor rotation direction which can be arbitrary.
if drive_mode:
values[i] *= -1
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
end_pos = self.calibration["end_pos"][calib_idx]
# Convert from nominal lnear range of [0, 100] % to
# actual motor range of values which can be arbitrary.
values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos
values = np.round(values).astype(np.int32)
return values
def avoid_rotation_reset(self, values, motor_names, data_name):
if data_name not in self.track_positions:
self.track_positions[data_name] = {
"prev": [None] * len(self.motor_names),
# Assume False at initialization
"below_zero": [False] * len(self.motor_names),
"above_max": [False] * len(self.motor_names),
}
track = self.track_positions[data_name]
if motor_names is None:
motor_names = self.motor_names
for i, name in enumerate(motor_names):
idx = self.motor_names.index(name)
if track["prev"][idx] is None:
track["prev"][idx] = values[i]
continue
# Detect a full rotation occurred
if abs(track["prev"][idx] - values[i]) > 2048:
# Position went below 0 and got reset to 4095
if track["prev"][idx] < values[i]:
# So we set negative value by adding a full rotation
values[i] -= 4096
# Position went above 4095 and got reset to 0
elif track["prev"][idx] > values[i]:
# So we add a full rotation
values[i] += 4096
track["prev"][idx] = values[i]
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
return_list = True
if not isinstance(motor_ids, list):
return_list = False
motor_ids = [motor_ids]
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
for idx in motor_ids:
group.addParam(idx)
for _ in range(num_retry):
comm = group.txRxPacket()
if comm == scs.COMM_SUCCESS:
break
if comm != scs.COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = group.getData(idx, addr, bytes)
values.append(value)
if return_list:
return values
else:
return values[0]
def read(self, data_name, motor_names: str | list[str] | None = None):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
if data_name not in self.group_readers:
# Very Important to flush the buffer!
self.port_handler.ser.reset_output_buffer()
self.port_handler.ser.reset_input_buffer()
# create new group reader
self.group_readers[group_key] = scs.GroupSyncRead(
self.port_handler, self.packet_handler, addr, bytes
)
for idx in motor_ids:
self.group_readers[group_key].addParam(idx)
for _ in range(NUM_READ_RETRY):
comm = self.group_readers[group_key].txRxPacket()
if comm == scs.COMM_SUCCESS:
break
if comm != scs.COMM_SUCCESS:
raise ConnectionError(
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
values = []
for idx in motor_ids:
value = self.group_readers[group_key].getData(idx, addr, bytes)
values.append(value)
values = np.array(values)
# Convert to signed int to use range [-2048, 2048] for our motor positions.
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
values = values.astype(np.int32)
if data_name in CALIBRATION_REQUIRED:
values = self.avoid_rotation_reset(values, motor_names, data_name)
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
if not isinstance(motor_ids, list):
motor_ids = [motor_ids]
if not isinstance(values, list):
values = [values]
assert_same_address(self.model_ctrl_table, motor_models, data_name)
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock)
group.addParam(idx, data)
for _ in range(num_retry):
comm = group.txPacket()
if comm == scs.COMM_SUCCESS:
break
if comm != scs.COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
)
start_time = time.perf_counter()
if self.mock:
import tests.motors.mock_scservo_sdk as scs
else:
import scservo_sdk as scs
if motor_names is None:
motor_names = self.motor_names
if isinstance(motor_names, str):
motor_names = [motor_names]
if isinstance(values, (int, float, np.integer)):
values = [int(values)] * len(motor_names)
values = np.array(values)
motor_ids = []
models = []
for name in motor_names:
motor_idx, model = self.motors[name]
motor_ids.append(motor_idx)
models.append(model)
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
values = self.revert_calibration(values, motor_names)
values = values.tolist()
assert_same_address(self.model_ctrl_table, models, data_name)
addr, bytes = self.model_ctrl_table[model][data_name]
group_key = get_group_sync_key(data_name, motor_names)
init_group = data_name not in self.group_readers
if init_group:
self.group_writers[group_key] = scs.GroupSyncWrite(
self.port_handler, self.packet_handler, addr, bytes
)
for idx, value in zip(motor_ids, values, strict=True):
data = convert_to_bytes(value, bytes, self.mock)
if init_group:
self.group_writers[group_key].addParam(idx, data)
else:
self.group_writers[group_key].changeParam(idx, data)
comm = self.group_writers[group_key].txPacket()
if comm != scs.COMM_SUCCESS:
raise ConnectionError(
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
f"{self.packet_handler.getTxRxResult(comm)}"
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?
# log the utc time when the write has been completed
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
self.logs[ts_utc_name] = capture_timestamp_utc()
def disconnect(self):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
)
if self.port_handler is not None:
self.port_handler.closePort()
self.port_handler = None
self.packet_handler = None
self.group_readers = {}
self.group_writers = {}
self.is_connected = False
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()

View File

@@ -12,21 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configs import MotorsBusConfig
from .motors_bus import MotorsBus
from typing import Protocol
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
class MotorsBus(Protocol):
def motor_names(self): ...
def set_calibration(self): ...
def apply_calibration(self): ...
def revert_calibration(self): ...
def read(self): ...
def write(self): ...
def make_motors_buses_from_configs(
motors_bus_configs: dict[str, MotorsBusConfig],
) -> list[MotorsBus]:
motors_buses = {}
for key, cfg in motors_bus_configs.items():
if cfg.type == "dynamixel":
from .dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
motors_buses[key] = DynamixelMotorsBus(cfg)
elif cfg.type == "feetech":
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
motors_buses[key] = FeetechMotorsBus(cfg)
@@ -38,16 +54,13 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
if motor_type == "dynamixel":
from .configs import DynamixelMotorsBusConfig
from .dynamixel import DynamixelMotorsBus
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
config = DynamixelMotorsBusConfig(**kwargs)
return DynamixelMotorsBus(config)
elif motor_type == "feetech":
from feetech import FeetechMotorsBus
from .configs import FeetechMotorsBusConfig
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
config = FeetechMotorsBusConfig(**kwargs)
return FeetechMotorsBus(config)

View File

@@ -0,0 +1,613 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import dataclass, field
from typing import Sequence
import draccus
from lerobot.common.robot_devices.cameras.configs import (
CameraConfig,
IntelRealSenseCameraConfig,
OpenCVCameraConfig,
)
from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
)
@dataclass
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
@dataclass
class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
mock: bool = False
def __post_init__(self):
if self.mock:
for arm in self.leader_arms.values():
if not arm.mock:
arm.mock = True
for arm in self.follower_arms.values():
if not arm.mock:
arm.mock = True
for cam in self.cameras.values():
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
@RobotConfig.register_subclass("aloha")
@dataclass
class AlohaRobotConfig(ManipulatorRobotConfig):
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
# simply update this path to ".cache/calibration/aloha"
calibration_dir: str = ".cache/calibration/aloha_default"
# /!\ FOR SAFETY, READ THIS /!\
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
# When you feel more confident with teleoperation or running the policy, you can extend
# this safety limit and even removing it by setting it to `null`.
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
max_relative_target: int | None = 5
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_left",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
"right": DynamixelMotorsBusConfig(
# window_x
port="/dev/ttyDXL_leader_right",
motors={
# name: (index, model)
"waist": [1, "xm430-w350"],
"shoulder": [2, "xm430-w350"],
"shoulder_shadow": [3, "xm430-w350"],
"elbow": [4, "xm430-w350"],
"elbow_shadow": [5, "xm430-w350"],
"forearm_roll": [6, "xm430-w350"],
"wrist_angle": [7, "xm430-w350"],
"wrist_rotate": [8, "xl430-w250"],
"gripper": [9, "xc430-w150"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_left",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/ttyDXL_follower_right",
motors={
# name: (index, model)
"waist": [1, "xm540-w270"],
"shoulder": [2, "xm540-w270"],
"shoulder_shadow": [3, "xm540-w270"],
"elbow": [4, "xm540-w270"],
"elbow_shadow": [5, "xm540-w270"],
"forearm_roll": [6, "xm540-w270"],
"wrist_angle": [7, "xm540-w270"],
"wrist_rotate": [8, "xm430-w350"],
"gripper": [9, "xm430-w350"],
},
),
}
)
# Troubleshooting: If one of your IntelRealSense cameras freeze during
# data recording due to bandwidth limit, you might need to plug the camera
# on another USB hub or PCIe card.
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"cam_high": IntelRealSenseCameraConfig(
serial_number=128422271347,
fps=30,
width=640,
height=480,
),
"cam_low": IntelRealSenseCameraConfig(
serial_number=130322270656,
fps=30,
width=640,
height=480,
),
"cam_left_wrist": IntelRealSenseCameraConfig(
serial_number=218622272670,
fps=30,
width=640,
height=480,
),
"cam_right_wrist": IntelRealSenseCameraConfig(
serial_number=130322272300,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("koch")
@dataclass
class KochRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("koch_bimanual")
@dataclass
class KochBimanualRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/koch_bimanual"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0085511",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0031751",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl330-m077"],
"shoulder_lift": [2, "xl330-m077"],
"elbow_flex": [3, "xl330-m077"],
"wrist_flex": [4, "xl330-m077"],
"wrist_roll": [5, "xl330-m077"],
"gripper": [6, "xl330-m077"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"left": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
"right": DynamixelMotorsBusConfig(
port="/dev/tty.usbmodem575E0032081",
motors={
# name: (index, model)
"shoulder_pan": [1, "xl430-w250"],
"shoulder_lift": [2, "xl430-w250"],
"elbow_flex": [3, "xl330-m288"],
"wrist_flex": [4, "xl330-m288"],
"wrist_roll": [5, "xl330-m288"],
"gripper": [6, "xl330-m288"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
# ~ Koch specific settings ~
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
# to squeeze the gripper and have it spring back to an open position on its own.
gripper_open_degree: float = 35.156
mock: bool = False
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/moss"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431091",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0076891",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("so100")
@dataclass
class So100RobotConfig(ManipulatorRobotConfig):
calibration_dir: str = ".cache/calibration/so100"
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760433331",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem58760431631",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"laptop": OpenCVCameraConfig(
camera_index=0,
fps=30,
width=640,
height=480,
),
"phone": OpenCVCameraConfig(
camera_index=1,
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("stretch")
@dataclass
class StretchRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
camera_index="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": IntelRealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": IntelRealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Network Configuration
ip: str = "192.168.0.193"
port: int = 5555
video_port: int = 5556
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"front": OpenCVCameraConfig(
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
),
"wrist": OpenCVCameraConfig(
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
),
}
)
calibration_dir: str = ".cache/calibration/lekiwi"
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0077581",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/ttyACM0",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
"left_wheel": (7, "sts3215"),
"back_wheel": (8, "sts3215"),
"right_wheel": (9, "sts3215"),
},
),
}
)
teleop_keys: dict[str, str] = field(
default_factory=lambda: {
# Movement
"forward": "w",
"backward": "s",
"left": "a",
"right": "d",
"rotate_left": "z",
"rotate_right": "x",
# Speed control
"speed_up": "r",
"speed_down": "f",
# quit teleop
"quit": "q",
}
)
mock: bool = False

View File

@@ -17,9 +17,12 @@
import numpy as np
from ..motors_bus import MotorNormMode, MotorsBus
from .dynamixel import TorqueMode
from .tables import MODEL_RESOLUTION
from lerobot.common.robot_devices.motors.dynamixel import (
CalibrationMode,
TorqueMode,
convert_degrees_to_steps,
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
@@ -46,17 +49,6 @@ def apply_drive_mode(position, drive_mode):
return position
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
"""
resolutions = [MODEL_RESOLUTION[model] for model in models]
steps = degrees / 180 * np.array(resolutions) / 2
steps = steps.astype(int)
return steps
def compute_nearest_rounded_position(position, models):
delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models)
nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn
@@ -97,11 +89,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
# correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.models)
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
zero_pos = arm.read("Present_Position")
zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.models)
zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models)
homing_offset = zero_target_pos - zero_nearest_pos
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
@@ -115,7 +107,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.models)
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -124,7 +116,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.models)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
@@ -133,13 +125,13 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
print()
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
calib_mode = [MotorNormMode.DEGREE.name] * len(arm.names)
calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names)
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
if robot_type in ["aloha"] and "gripper" in arm.names:
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
calib_idx = arm.names.index("gripper")
calib_mode[calib_idx] = MotorNormMode.LINEAR.name
calib_idx = arm.motor_names.index("gripper")
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
calib_data = {
"homing_offset": homing_offset.tolist(),
@@ -147,6 +139,6 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
"start_pos": zero_pos.tolist(),
"end_pos": rotated_pos.tolist(),
"calib_mode": calib_mode,
"motor_names": arm.names,
"motor_names": arm.motor_names,
}
return calib_data

View File

@@ -0,0 +1,509 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Logic to calibrate a robot arm built with feetech motors"""
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
import time
import numpy as np
from lerobot.common.robot_devices.motors.feetech import (
CalibrationMode,
TorqueMode,
convert_degrees_to_steps,
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
ZERO_POSITION_DEGREE = 0
ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
def apply_drive_mode(position, drive_mode):
assert_drive_mode(drive_mode)
# Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
# to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
signed_drive_mode = -(drive_mode * 2 - 1)
position *= signed_drive_mode
return position
def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None):
count = 0
while True:
present_pos = arm.read("Present_Position", motor_name)
if positive_direction:
# Move +100 steps every time. Lower the steps to lower the speed at which the arm moves.
arm.write("Goal_Position", present_pos + 100, motor_name)
else:
arm.write("Goal_Position", present_pos - 100, motor_name)
if while_move_hook is not None:
while_move_hook()
present_pos = arm.read("Present_Position", motor_name).item()
present_speed = arm.read("Present_Speed", motor_name).item()
present_current = arm.read("Present_Current", motor_name).item()
# present_load = arm.read("Present_Load", motor_name).item()
# present_voltage = arm.read("Present_Voltage", motor_name).item()
# present_temperature = arm.read("Present_Temperature", motor_name).item()
# print(f"{present_pos=}")
# print(f"{present_speed=}")
# print(f"{present_current=}")
# print(f"{present_load=}")
# print(f"{present_voltage=}")
# print(f"{present_temperature=}")
if present_speed == 0 and present_current > 40:
count += 1
if count > 100 or present_current > 300:
return present_pos
else:
count = 0
def move_to_calibrate(
arm,
motor_name,
invert_drive_mode=False,
positive_first=True,
in_between_move_hook=None,
while_move_hook=None,
):
initial_pos = arm.read("Present_Position", motor_name)
if positive_first:
p_present_pos = move_until_block(
arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
)
else:
n_present_pos = move_until_block(
arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
)
if in_between_move_hook is not None:
in_between_move_hook()
if positive_first:
n_present_pos = move_until_block(
arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
)
else:
p_present_pos = move_until_block(
arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
)
zero_pos = (n_present_pos + p_present_pos) / 2
calib_data = {
"initial_pos": initial_pos,
"homing_offset": zero_pos if invert_drive_mode else -zero_pos,
"invert_drive_mode": invert_drive_mode,
"drive_mode": -1 if invert_drive_mode else 0,
"zero_pos": zero_pos,
"start_pos": n_present_pos if invert_drive_mode else p_present_pos,
"end_pos": p_present_pos if invert_drive_mode else n_present_pos,
}
return calib_data
def apply_offset(calib, offset):
calib["zero_pos"] += offset
if calib["drive_mode"]:
calib["homing_offset"] += offset
else:
calib["homing_offset"] -= offset
return calib
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type)
else:
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
initial_acceleration = arm.read("Acceleration")
arm.write("Lock", 0)
arm.write("Acceleration", 10)
time.sleep(1)
arm.write("Torque_Enable", TorqueMode.ENABLED.value)
print(f'{arm.read("Present_Position", "elbow_flex")=}')
calib = {}
init_wf_pos = arm.read("Present_Position", "wrist_flex")
init_sl_pos = arm.read("Present_Position", "shoulder_lift")
init_ef_pos = arm.read("Present_Position", "elbow_flex")
arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex")
arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift")
arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex")
time.sleep(2)
print("Calibrate shoulder_pan")
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
print("Calibrate gripper")
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
time.sleep(1)
print("Calibrate wrist_flex")
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
def in_between_move_hook():
nonlocal arm, calib
time.sleep(2)
ef_pos = arm.read("Present_Position", "elbow_flex")
sl_pos = arm.read("Present_Position", "shoulder_lift")
arm.write("Goal_Position", ef_pos + 1024, "elbow_flex")
arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift")
time.sleep(2)
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
time.sleep(1)
def in_between_move_hook():
nonlocal arm, calib
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
print("Calibrate shoulder_lift")
calib["shoulder_lift"] = move_to_calibrate(
arm,
"shoulder_lift",
invert_drive_mode=True,
positive_first=False,
in_between_move_hook=in_between_move_hook,
)
# add an 30 steps as offset to align with body
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
def while_move_hook():
nonlocal arm, calib
positions = {
"shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600),
"elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700),
"wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800),
"gripper": round(calib["gripper"]["end_pos"]),
}
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
time.sleep(2)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(
arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
time.sleep(1)
arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
time.sleep(1)
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
calib_modes = []
for name in arm.motor_names:
if name == "gripper":
calib_modes.append(CalibrationMode.LINEAR.name)
else:
calib_modes.append(CalibrationMode.DEGREE.name)
calib_dict = {
"homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
"drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
"start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
"end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
"calib_mode": calib_modes,
"motor_names": arm.motor_names,
}
# Re-enable original accerlation
arm.write("Lock", 0)
arm.write("Acceleration", initial_acceleration)
time.sleep(1)
return calib_dict
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
initial_acceleration = arm.read("Acceleration")
arm.write("Lock", 0)
arm.write("Acceleration", 10)
time.sleep(1)
arm.write("Torque_Enable", TorqueMode.ENABLED.value)
sl_pos = arm.read("Present_Position", "shoulder_lift")
arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift")
ef_pos = arm.read("Present_Position", "elbow_flex")
arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex")
time.sleep(2)
calib = {}
print("Calibrate shoulder_pan")
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
print("Calibrate gripper")
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
time.sleep(1)
print("Calibrate wrist_flex")
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)
wr_pos = arm.read("Present_Position", "wrist_roll")
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", wr_pos - 1024, "wrist_roll")
time.sleep(1)
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper")
time.sleep(1)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
def in_between_move_elbow_flex_hook():
nonlocal arm, calib
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm,
"elbow_flex",
invert_drive_mode=True,
in_between_move_hook=in_between_move_elbow_flex_hook,
)
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
def in_between_move_shoulder_lift_hook():
nonlocal arm, calib
sl = arm.read("Present_Position", "shoulder_lift")
arm.write("Goal_Position", sl - 1500, "shoulder_lift")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex")
time.sleep(1)
arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex")
time.sleep(1)
print("Calibrate shoulder_lift")
calib["shoulder_lift"] = move_to_calibrate(
arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook
)
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024)
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
time.sleep(2)
calib_modes = []
for name in arm.motor_names:
if name == "gripper":
calib_modes.append(CalibrationMode.LINEAR.name)
else:
calib_modes.append(CalibrationMode.DEGREE.name)
calib_dict = {
"homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
"drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
"start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
"end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
"calib_mode": calib_modes,
"motor_names": arm.motor_names,
}
# Re-enable original accerlation
arm.write("Lock", 0)
arm.write("Acceleration", initial_acceleration)
time.sleep(1)
return calib_dict
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance before calibration, setting a same goal position
for each motor of two different robots will get two very different positions. But after calibration,
the two robots will move to the same position.To this end, this function computes the homing offset
and the drive mode for each motor of a given robot.
Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
being 0. During the calibration process, you will need to manually move the robot to this "zero position".
Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
to the "rotated position".
After calibration, the homing offsets and drive modes are stored in a cache.
Example of usage:
```python
run_arm_calibration(arm, "so100", "left", "follower")
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
# correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
zero_pos = arm.read("Present_Position")
homing_offset = zero_target_pos - zero_pos
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
# This allows to identify the rotation direction of each motor.
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinetic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
rotated_pos = arm.read("Present_Position")
drive_mode = (rotated_pos < zero_pos).astype(np.int32)
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
input("Press Enter to continue...")
print()
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
calib_modes = []
for name in arm.motor_names:
if name == "gripper":
calib_modes.append(CalibrationMode.LINEAR.name)
else:
calib_modes.append(CalibrationMode.DEGREE.name)
calib_dict = {
"homing_offset": homing_offset.tolist(),
"drive_mode": drive_mode.tolist(),
"start_pos": zero_pos.tolist(),
"end_pos": rotated_pos.tolist(),
"calib_mode": calib_modes,
"motor_names": arm.motor_names,
}
return calib_dict

View File

@@ -21,7 +21,7 @@ from pathlib import Path
import cv2
import zmq
from lerobot.common.robots.mobile_manipulator import LeKiwi
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
def setup_zmq_sockets(config):
@@ -61,7 +61,9 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
calib_dir.mkdir(parents=True, exist_ok=True)
calib_file = calib_dir / "main_follower.json"
try:
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
except ImportError:
print("[WARNING] Calibration function not available. Skipping calibration.")
return
@@ -72,7 +74,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
print(f"[INFO] Loaded calibration from {calib_file}")
else:
print("[INFO] Calibration file not found. Running manual calibration...")
calibration = run_full_arm_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
print(f"[INFO] Calibration complete. Saving to {calib_file}")
with open(calib_file, "w") as f:
json.dump(calibration, f)
@@ -93,8 +95,8 @@ def run_lekiwi(robot_config):
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
"""
# Import helper functions and classes
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus, TorqueMode
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
# Initialize cameras from the robot configuration.
cameras = make_cameras_from_configs(robot_config.cameras)
@@ -116,7 +118,14 @@ def run_lekiwi(robot_config):
robot = LeKiwi(motors_bus)
# Define the expected arm motor IDs.
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
arm_motor_ids = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
# Disable torque for each arm motor.
for motor in arm_motor_ids:
@@ -130,7 +139,9 @@ def run_lekiwi(robot_config):
images_lock = threading.Lock()
stop_event = threading.Event()
cam_thread = threading.Thread(
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
target=run_camera_capture,
args=(cameras, images_lock, latest_images_dict, stop_event),
daemon=True,
)
cam_thread.start()

View File

@@ -18,66 +18,48 @@ and send orders to its motors.
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
# calibration procedure, to make it easy for people to add their own robot.
import json
import logging
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
from lerobot.common.cameras.configs import CameraConfig
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors.configs import MotorsBusConfig
from lerobot.common.motors.motors_bus import MotorsBus
from lerobot.common.motors.utils import make_motors_buses_from_configs
from lerobot.common.robots.config import RobotConfig
from lerobot.common.robots.utils import ensure_safe_goal_position, get_arm_id
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
@dataclass
class ManipulatorRobotConfig(RobotConfig):
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
def ensure_safe_goal_position(
goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
max_relative_target = torch.tensor(max_relative_target)
safe_diff = torch.minimum(diff, max_relative_target)
safe_diff = torch.maximum(safe_diff, -max_relative_target)
safe_goal_pos = present_pos + safe_diff
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
# as the number of motors in your follower arms (assumes all follower arms have the same number of
# motors).
max_relative_target: list[float] | float | None = None
if not torch.allclose(goal_pos, safe_goal_pos):
logging.debug(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
)
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
mock: bool = False
def __post_init__(self):
if self.mock:
for arm in self.leader_arms.values():
if not arm.mock:
arm.mock = True
for arm in self.follower_arms.values():
if not arm.mock:
arm.mock = True
for cam in self.cameras.values():
if not cam.mock:
cam.mock = True
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
raise ValueError(
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
f"`max_relative_target` list has as many parameters as there are motors per arm. "
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
return safe_goal_pos
class ManipulatorRobot:
@@ -250,7 +232,7 @@ class ManipulatorRobot:
def connect(self):
if self.is_connected:
raise DeviceAlreadyConnectedError(
raise RobotDeviceAlreadyConnectedError(
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
)
@@ -268,9 +250,9 @@ class ManipulatorRobot:
self.leader_arms[name].connect()
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.motors.dynamixel.dynamixel import TorqueMode
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.motors.feetech.feetech import TorqueMode
from lerobot.common.robot_devices.motors.feetech import TorqueMode
# We assume that at connection time, arms are in a rest position, and torque can
# be safely disabled to run calibration and/or set robot preset configurations.
@@ -279,6 +261,8 @@ class ManipulatorRobot:
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
self.activate_calibration()
# Set robot preset (e.g. torque in leader gripper for Koch v1.1)
if self.robot_type in ["koch", "koch_bimanual"]:
self.set_koch_robot_preset()
@@ -315,9 +299,54 @@ class ManipulatorRobot:
self.is_connected = True
def activate_calibration(self):
"""After calibration all motors function in human interpretable ranges.
Rotations are expressed in degrees in nominal range of [-180, 180],
and linear motions (like gripper of Aloha) in nominal range of [0, 100].
"""
def load_or_run_calibration_(name, arm, arm_type):
arm_id = get_arm_id(name, arm_type)
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
if arm_calib_path.exists():
with open(arm_calib_path) as f:
calibration = json.load(f)
else:
# TODO(rcadene): display a warning in __init__ if calibration file not available
print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
elif self.robot_type in ["so100", "moss", "lekiwi"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
return calibration
for name, arm in self.follower_arms.items():
calibration = load_or_run_calibration_(name, arm, "follower")
arm.set_calibration(calibration)
for name, arm in self.leader_arms.items():
calibration = load_or_run_calibration_(name, arm, "leader")
arm.set_calibration(calibration)
def set_koch_robot_preset(self):
def set_operating_mode_(arm):
from lerobot.common.motors.dynamixel.dynamixel import TorqueMode
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
@@ -415,6 +444,9 @@ class ManipulatorRobot:
# Set I_Coefficient and D_Coefficient to default value 0 and 32
self.follower_arms[name].write("I_Coefficient", 0)
self.follower_arms[name].write("D_Coefficient", 32)
# Close the write lock so that Maximum_Acceleration gets written to EPROM address,
# which is mandatory for Maximum_Acceleration to take effect after rebooting.
self.follower_arms[name].write("Lock", 0)
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
# the motors. Note: this configuration is not in the official STS3215 Memory Table
self.follower_arms[name].write("Maximum_Acceleration", 254)
@@ -424,7 +456,7 @@ class ManipulatorRobot:
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
@@ -442,6 +474,14 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
@@ -504,7 +544,7 @@ class ManipulatorRobot:
def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
@@ -550,7 +590,7 @@ class ManipulatorRobot:
action: tensor containing the concatenated goal positions for the follower arms.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
@@ -563,6 +603,14 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
# if self.config.joint_position_relative_bounds is not None:
# goal_pos = torch.clamp(
# goal_pos,
# self.config.joint_position_relative_bounds["min"],
# self.config.joint_position_relative_bounds["max"],
# )
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
@@ -585,7 +633,7 @@ class ManipulatorRobot:
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(
raise RobotDeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
)

View File

@@ -23,14 +23,18 @@ import numpy as np
import torch
import zmq
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.errors import DeviceNotConnectedError
from lerobot.common.motors.feetech.feetech import TorqueMode
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
from lerobot.common.motors.motors_bus import MotorsBus
from lerobot.common.motors.utils import make_motors_buses_from_configs
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
from lerobot.common.robots.utils import get_arm_id
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.motors.feetech import TorqueMode
from lerobot.common.robot_devices.motors.utils import (
MotorsBus,
make_motors_buses_from_configs,
)
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
PYNPUT_AVAILABLE = True
try:
@@ -267,7 +271,7 @@ class MobileManipulator:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
@@ -325,7 +329,11 @@ class MobileManipulator:
socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Drain all messages, keep only the last
last_msg = None
@@ -338,7 +346,11 @@ class MobileManipulator:
if not last_msg:
# No new message → also reuse old
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
# Decode only the final message
try:
@@ -376,7 +388,11 @@ class MobileManipulator:
except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return (
self.last_frames,
self.last_present_speed,
self.last_remote_arm_state,
)
return frames, present_speed, remote_arm_state_tensor
@@ -396,7 +412,7 @@ class MobileManipulator:
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
@@ -456,13 +472,17 @@ class MobileManipulator:
and a camera frame.
"""
if not self.is_connected:
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
body_state_mm = (
body_state[0] * 1000.0,
body_state[1] * 1000.0,
body_state[2],
) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
@@ -480,7 +500,7 @@ class MobileManipulator:
def send_action(self, action: torch.Tensor) -> torch.Tensor:
if not self.is_connected:
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
# Ensure the action tensor has at least 9 elements:
# - First 6: arm positions.
@@ -518,7 +538,7 @@ class MobileManipulator:
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError("Not connected.")
raise RobotDeviceNotConnectedError("Not connected.")
if self.cmd_socket:
stop_cmd = {
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
@@ -621,7 +641,11 @@ class MobileManipulator:
# Convert each wheels angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
return {
"left_wheel": wheel_raw[0],
"back_wheel": wheel_raw[1],
"right_wheel": wheel_raw[2],
}
def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125

View File

@@ -0,0 +1,208 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from dataclasses import replace
import torch
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
class StretchRobot(StretchAPI):
"""Wrapper of stretch_body.robot.Robot"""
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
super().__init__()
if config is None:
self.config = StretchRobotConfig(**kwargs)
else:
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.robot_type = self.config.type
self.cameras = self.config.cameras
self.is_connected = False
self.teleop = None
self.logs = {}
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
self.state_keys = None
self.action_keys = None
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.run_calibration()
def run_calibration(self) -> None:
if not self.is_homed():
self.home()
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected:
raise ConnectionError()
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
before_read_t = time.perf_counter()
state = self.get_state()
action = self.teleop.gamepad_controller.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
before_write_t = time.perf_counter()
self.teleop.do_motion(robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
if self.state_keys is None:
self.state_keys = list(state)
if not record_data:
return
state = torch.as_tensor(list(state.values()))
action = torch.as_tensor(list(action.values()))
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def get_state(self) -> dict:
status = self.get_status()
return {
"head_pan.pos": status["head"]["head_pan"]["pos"],
"head_tilt.pos": status["head"]["head_tilt"]["pos"],
"lift.pos": status["lift"]["pos"],
"arm.pos": status["arm"]["pos"],
"wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
"wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
"wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
"gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
"base_x.vel": status["base"]["x_vel"],
"base_y.vel": status["base"]["y_vel"],
"base_theta.vel": status["base"]["theta_vel"],
}
def capture_observation(self) -> dict:
# TODO(aliberts): return ndarrays instead of torch.Tensors
before_read_t = time.perf_counter()
state = self.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
state = torch.as_tensor(list(state.values()))
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
# TODO(aliberts): return ndarrays instead of torch.Tensors
if not self.is_connected:
raise ConnectionError()
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
if self.action_keys is None:
dummy_action = self.teleop.gamepad_controller.get_state()
self.action_keys = list(dummy_action.keys())
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
before_write_t = time.perf_counter()
self.teleop.do_motion(state=action_dict, robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# TODO(aliberts): return action_sent when motion is limited
return action
def print_logs(self) -> None:
pass
# TODO(aliberts): move robot-specific logs logic here
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
def disconnect(self) -> None:
self.stop()
if self.teleop is not None:
self.teleop.gamepad_controller.stop()
self.teleop.stop()
if len(self.cameras) > 0:
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False
def __del__(self):
self.disconnect()

View File

@@ -0,0 +1,88 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol
from lerobot.common.robot_devices.robots.configs import (
AlohaRobotConfig,
KochBimanualRobotConfig,
KochRobotConfig,
LeKiwiRobotConfig,
ManipulatorRobotConfig,
MossRobotConfig,
RobotConfig,
So100RobotConfig,
StretchRobotConfig,
)
def get_arm_id(name, arm_type):
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
"""
return f"{name}_{arm_type}"
class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str
features: dict
def connect(self): ...
def run_calibration(self): ...
def teleop_step(self, record_data=False): ...
def capture_observation(self): ...
def send_action(self, action): ...
def disconnect(self): ...
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
if robot_type == "aloha":
return AlohaRobotConfig(**kwargs)
elif robot_type == "koch":
return KochRobotConfig(**kwargs)
elif robot_type == "koch_bimanual":
return KochBimanualRobotConfig(**kwargs)
elif robot_type == "moss":
return MossRobotConfig(**kwargs)
elif robot_type == "so100":
return So100RobotConfig(**kwargs)
elif robot_type == "stretch":
return StretchRobotConfig(**kwargs)
elif robot_type == "lekiwi":
return LeKiwiRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
def make_robot_from_config(config: RobotConfig):
if isinstance(config, ManipulatorRobotConfig):
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
return ManipulatorRobot(config)
elif isinstance(config, LeKiwiRobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import (
MobileManipulator,
)
return MobileManipulator(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot
return StretchRobot(config)
def make_robot(robot_type: str, **kwargs) -> Robot:
config = make_robot_config(robot_type, **kwargs)
return make_robot_from_config(config)

View File

@@ -42,3 +42,25 @@ def safe_disconnect(func):
raise e
return wrapper
class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
):
self.message = message
super().__init__(self.message)
class RobotDeviceAlreadyConnectedError(Exception):
"""Exception raised when the robot device is already connected."""
def __init__(
self,
message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
):
self.message = message
super().__init__(self.message)

View File

@@ -1,4 +0,0 @@
from .config import RobotConfig
from .robot import Robot
__all__ = ["RobotConfig", "Robot"]

View File

@@ -1,17 +0,0 @@
import abc
from dataclasses import dataclass
from pathlib import Path
import draccus
@dataclass(kw_only=True)
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
# Allows to distinguish between different robots of the same type
id: str | None = None
# Directory to store calibration file
calibration_dir: Path | None = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)

View File

@@ -1,2 +0,0 @@
from .config_koch_follower import KochFollowerConfig
from .koch_follower import KochFollower

View File

@@ -1,22 +0,0 @@
from dataclasses import dataclass, field
from lerobot.common.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("koch_follower")
@dataclass
class KochFollowerConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -1,230 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.dynamixel import (
DynamixelMotorsBus,
OperatingMode,
)
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_koch_follower import KochFollowerConfig
logger = logging.getLogger(__name__)
class KochFollower(Robot):
"""
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow
expansion, developed by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
"""
config_class = KochFollowerConfig
name = "koch_follower"
def __init__(self, config: KochFollowerConfig):
super().__init__(config)
self.config = config
self.arm = DynamixelMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100),
"shoulder_lift": Motor(2, "xl430-w250", MotorNormMode.RANGE_M100_100),
"elbow_flex": Motor(3, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_flex": Motor(4, "xl330-m288", MotorNormMode.RANGE_M100_100),
"wrist_roll": Motor(5, "xl330-m288", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "xl330-m288", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def state_feature(self) -> dict:
return {
"dtype": "float32",
"shape": (len(self.arm),),
"names": {"motors": list(self.arm.motors)},
}
@property
def action_feature(self) -> dict:
return self.state_feature
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def is_connected(self) -> bool:
# TODO(aliberts): add cam.is_connected for cam in self.cameras
return self.arm.is_connected
def connect(self) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.arm.connect()
if not self.is_calibrated:
self.calibrate()
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.arm.is_calibrated
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
self.arm.disable_torque()
for name in self.arm.names:
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.arm.set_half_turn_homings()
full_turn_motors = ["shoulder_pan", "wrist_roll"]
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
logger.info(
f"Move all joints except {full_turn_motors} sequentially through their entire "
"ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.arm.record_ranges_of_motion(unknown_range_motors)
for name in full_turn_motors:
range_mins[name] = 0
range_maxes[name] = 4095
self.calibration = {}
for name, motor in self.arm.motors.items():
self.calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=0,
homing_offset=homing_offsets[name],
range_min=range_mins[name],
range_max=range_maxes[name],
)
self.arm.write_calibration(self.calibration)
self._save_calibration()
logger.info(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
with self.arm.torque_disabled():
self.arm.configure_motors()
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
for name in self.arm.names:
if name != "gripper":
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
# Use 'position control current based' for gripper to be limited by the limit of the current. For
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
# our finger to make it move, and it will move back to its original target position when we
# release the force.
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
# Set better PID values to close the gap between recorded states and actions
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
self.arm.write("Position_I_Gain", "elbow_flex", 0)
self.arm.write("Position_D_Gain", "elbow_flex", 600)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Read arm position
start = time.perf_counter()
obs_dict[OBS_STATE] = self.arm.sync_read("Present_Position")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, float]) -> dict[str, float]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (dict[str, float]): The goal positions for the motors.
Returns:
dict[str, float]: The action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = action
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.arm.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.arm.sync_write("Goal_Position", goal_pos)
return goal_pos
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.arm.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -1,89 +0,0 @@
from dataclasses import dataclass, field
from lerobot.common.cameras.configs import CameraConfig
from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.common.motors.configs import FeetechMotorsBusConfig, MotorsBusConfig
from lerobot.common.robots.config import RobotConfig
@RobotConfig.register_subclass("lekiwi")
@dataclass
class LeKiwiRobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# Network Configuration
ip: str = "192.168.0.193"
port: int = 5555
video_port: int = 5556
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"front": OpenCVCameraConfig(
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
),
"wrist": OpenCVCameraConfig(
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
),
}
)
calibration_dir: str = ".cache/calibration/lekiwi"
leader_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/tty.usbmodem585A0077581",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
},
),
}
)
follower_arms: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": FeetechMotorsBusConfig(
port="/dev/ttyACM0",
motors={
# name: (index, model)
"shoulder_pan": [1, "sts3215"],
"shoulder_lift": [2, "sts3215"],
"elbow_flex": [3, "sts3215"],
"wrist_flex": [4, "sts3215"],
"wrist_roll": [5, "sts3215"],
"gripper": [6, "sts3215"],
"left_wheel": (7, "sts3215"),
"back_wheel": (8, "sts3215"),
"right_wheel": (9, "sts3215"),
},
),
}
)
teleop_keys: dict[str, str] = field(
default_factory=lambda: {
# Movement
"forward": "w",
"backward": "s",
"left": "a",
"right": "d",
"rotate_left": "z",
"rotate_right": "x",
# Speed control
"speed_up": "r",
"speed_down": "f",
# quit teleop
"quit": "q",
}
)
mock: bool = False

View File

@@ -1,692 +0,0 @@
import base64
import json
import os
import sys
from pathlib import Path
import cv2
import numpy as np
import torch
import zmq
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.errors import DeviceNotConnectedError
from lerobot.common.motors.feetech.feetech import TorqueMode
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
from lerobot.common.motors.motors_bus import MotorsBus
from lerobot.common.motors.utils import make_motors_buses_from_configs
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
from lerobot.common.robots.utils import get_arm_id
PYNPUT_AVAILABLE = True
try:
# Only import if there's a valid X server or if we're not on a Pi
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
print("No DISPLAY set. Skipping pynput import.")
raise ImportError("pynput blocked intentionally due to no display.")
from pynput import keyboard
except ImportError:
keyboard = None
PYNPUT_AVAILABLE = False
except Exception as e:
keyboard = None
PYNPUT_AVAILABLE = False
print(f"Could not import pynput: {e}")
class MobileManipulator:
"""
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
The robot includes a three omniwheel mobile base and a remote follower arm.
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
forwarded to the remote follower arm (after applying a safety clamp).
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
"""
def __init__(self, config: LeKiwiRobotConfig):
"""
Expected keys in config:
- ip, port, video_port for the remote connection.
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
"""
self.robot_type = config.type
self.config = config
self.remote_ip = config.ip
self.remote_port = config.port
self.remote_port_video = config.video_port
self.calibration_dir = Path(self.config.calibration_dir)
self.logs = {}
self.teleop_keys = self.config.teleop_keys
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
self.cameras = make_cameras_from_configs(self.config.cameras)
self.is_connected = False
self.last_frames = {}
self.last_present_speed = {}
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
# Define three speed levels and a current index
self.speed_levels = [
{"xy": 0.1, "theta": 30}, # slow
{"xy": 0.2, "theta": 60}, # medium
{"xy": 0.3, "theta": 90}, # fast
]
self.speed_index = 0 # Start at slow
# ZeroMQ context and sockets.
self.context = None
self.cmd_socket = None
self.video_socket = None
# Keyboard state for base teleoperation.
self.running = True
self.pressed_keys = {
"forward": False,
"backward": False,
"left": False,
"right": False,
"rotate_left": False,
"rotate_right": False,
}
if PYNPUT_AVAILABLE:
print("pynput is available - enabling local keyboard listener.")
self.listener = keyboard.Listener(
on_press=self.on_press,
on_release=self.on_release,
)
self.listener.start()
else:
print("pynput not available - skipping local keyboard listener.")
self.listener = None
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
follower_arm_names = [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
]
observations = ["x_mm", "y_mm", "theta"]
combined_names = follower_arm_names + observations
return {
"action": {
"dtype": "float32",
"shape": (len(combined_names),),
"names": combined_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(combined_names),),
"names": combined_names,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
@property
def available_arms(self):
available = []
for name in self.leader_arms:
available.append(get_arm_id(name, "leader"))
for name in self.follower_arms:
available.append(get_arm_id(name, "follower"))
return available
def on_press(self, key):
try:
# Movement
if key.char == self.teleop_keys["forward"]:
self.pressed_keys["forward"] = True
elif key.char == self.teleop_keys["backward"]:
self.pressed_keys["backward"] = True
elif key.char == self.teleop_keys["left"]:
self.pressed_keys["left"] = True
elif key.char == self.teleop_keys["right"]:
self.pressed_keys["right"] = True
elif key.char == self.teleop_keys["rotate_left"]:
self.pressed_keys["rotate_left"] = True
elif key.char == self.teleop_keys["rotate_right"]:
self.pressed_keys["rotate_right"] = True
# Quit teleoperation
elif key.char == self.teleop_keys["quit"]:
self.running = False
return False
# Speed control
elif key.char == self.teleop_keys["speed_up"]:
self.speed_index = min(self.speed_index + 1, 2)
print(f"Speed index increased to {self.speed_index}")
elif key.char == self.teleop_keys["speed_down"]:
self.speed_index = max(self.speed_index - 1, 0)
print(f"Speed index decreased to {self.speed_index}")
except AttributeError:
# e.g., if key is special like Key.esc
if key == keyboard.Key.esc:
self.running = False
return False
def on_release(self, key):
try:
if hasattr(key, "char"):
if key.char == self.teleop_keys["forward"]:
self.pressed_keys["forward"] = False
elif key.char == self.teleop_keys["backward"]:
self.pressed_keys["backward"] = False
elif key.char == self.teleop_keys["left"]:
self.pressed_keys["left"] = False
elif key.char == self.teleop_keys["right"]:
self.pressed_keys["right"] = False
elif key.char == self.teleop_keys["rotate_left"]:
self.pressed_keys["rotate_left"] = False
elif key.char == self.teleop_keys["rotate_right"]:
self.pressed_keys["rotate_right"] = False
except AttributeError:
pass
def connect(self):
if not self.leader_arms:
raise ValueError("MobileManipulator has no leader arm to connect.")
for name in self.leader_arms:
print(f"Connecting {name} leader arm.")
self.calibrate_leader()
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
self.context = zmq.Context()
self.cmd_socket = self.context.socket(zmq.PUSH)
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
self.cmd_socket.connect(connection_string)
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
self.video_socket = self.context.socket(zmq.PULL)
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
self.video_socket.connect(video_connection)
self.video_socket.setsockopt(zmq.CONFLATE, 1)
print(
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
)
self.is_connected = True
def load_or_run_calibration_(self, name, arm, arm_type):
arm_id = get_arm_id(name, arm_type)
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
if arm_calib_path.exists():
with open(arm_calib_path) as f:
calibration = json.load(f)
else:
print(f"Missing calibration file '{arm_calib_path}'")
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
return calibration
def calibrate_leader(self):
for name, arm in self.leader_arms.items():
# Connect the bus
arm.connect()
# Disable torque on all motors
for motor_id in arm.motors:
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
# Now run calibration
calibration = self.load_or_run_calibration_(name, arm, "leader")
arm.set_calibration(calibration)
def calibrate_follower(self):
for name, bus in self.follower_arms.items():
bus.connect()
# Disable torque on all motors
for motor_id in bus.motors:
bus.write("Torque_Enable", 0, motor_id)
# Then filter out wheels
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
if not arm_only_dict:
continue
original_motors = bus.motors
bus.motors = arm_only_dict
calibration = self.load_or_run_calibration_(name, bus, "follower")
bus.set_calibration(calibration)
bus.motors = original_motors
def _get_data(self):
"""
Polls the video socket for up to 15 ms. If data arrives, decode only
the *latest* message, returning frames, speed, and arm state. If
nothing arrives for any field, use the last known values.
"""
frames = {}
present_speed = {}
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
# Poll up to 15 ms
poller = zmq.Poller()
poller.register(self.video_socket, zmq.POLLIN)
socks = dict(poller.poll(15))
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
# No new data arrived → reuse ALL old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
# Drain all messages, keep only the last
last_msg = None
while True:
try:
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
last_msg = obs_string
except zmq.Again:
break
if not last_msg:
# No new message → also reuse old
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
# Decode only the final message
try:
observation = json.loads(last_msg)
images_dict = observation.get("images", {})
new_speed = observation.get("present_speed", {})
new_arm_state = observation.get("follower_arm_state", None)
# Convert images
for cam_name, image_b64 in images_dict.items():
if image_b64:
jpg_data = base64.b64decode(image_b64)
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if frame_candidate is not None:
frames[cam_name] = frame_candidate
# If remote_arm_state is None and frames is None there is no message then use the previous message
if new_arm_state is not None and frames is not None:
self.last_frames = frames
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
self.last_remote_arm_state = remote_arm_state_tensor
present_speed = new_speed
self.last_present_speed = new_speed
else:
frames = self.last_frames
remote_arm_state_tensor = self.last_remote_arm_state
present_speed = self.last_present_speed
except Exception as e:
print(f"[DEBUG] Error decoding video message: {e}")
# If decode fails, fall back to old data
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
return frames, present_speed, remote_arm_state_tensor
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
state_tensor = torch.zeros(3, dtype=torch.int32)
if present_speed:
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
if "1" in decoded:
state_tensor[0] = decoded["1"]
if "2" in decoded:
state_tensor[1] = decoded["2"]
if "3" in decoded:
state_tensor[2] = decoded["3"]
return state_tensor
def teleop_step(
self, record_data: bool = False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
speed_setting = self.speed_levels[self.speed_index]
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
# Prepare to assign the position of the leader to the follower
arm_positions = []
for name in self.leader_arms:
pos = self.leader_arms[name].read("Present_Position")
pos_tensor = torch.from_numpy(pos).float()
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
arm_positions.extend(pos_tensor.tolist())
# (The rest of your code for generating wheel commands remains unchanged)
x_cmd = 0.0 # m/s forward/backward
y_cmd = 0.0 # m/s lateral
theta_cmd = 0.0 # deg/s rotation
if self.pressed_keys["forward"]:
x_cmd += xy_speed
if self.pressed_keys["backward"]:
x_cmd -= xy_speed
if self.pressed_keys["left"]:
y_cmd += xy_speed
if self.pressed_keys["right"]:
y_cmd -= xy_speed
if self.pressed_keys["rotate_left"]:
theta_cmd += theta_speed
if self.pressed_keys["rotate_right"]:
theta_cmd -= theta_speed
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
self.cmd_socket.send_string(json.dumps(message))
if not record_data:
return
obs_dict = self.capture_observation()
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
wheel_velocity_mm = (
wheel_velocity_tuple[0] * 1000.0,
wheel_velocity_tuple[1] * 1000.0,
wheel_velocity_tuple[2],
)
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
action_dict = {"action": action_tensor}
return obs_dict, action_dict
def capture_observation(self) -> dict:
"""
Capture observations from the remote robot: current follower arm positions,
present wheel speeds (converted to body-frame velocities: x, y, theta),
and a camera frame.
"""
if not self.is_connected:
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
frames, present_speed, remote_arm_state_tensor = self._get_data()
body_state = self.wheel_raw_to_body(present_speed)
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
obs_dict = {"observation.state": combined_state_tensor}
# Loop over each configured camera
for cam_name, cam in self.cameras.items():
frame = frames.get(cam_name, None)
if frame is None:
# Create a black image using the camera's configured width, height, and channels
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
if not self.is_connected:
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
# Ensure the action tensor has at least 9 elements:
# - First 6: arm positions.
# - Last 3: base commands.
if action.numel() < 9:
# Pad with zeros if there are not enough elements.
padded = torch.zeros(9, dtype=action.dtype)
padded[: action.numel()] = action
action = padded
# Extract arm and base actions.
arm_actions = action[:6].flatten()
base_actions = action[6:].flatten()
x_cmd_mm = base_actions[0].item() # mm/s
y_cmd_mm = base_actions[1].item() # mm/s
theta_cmd = base_actions[2].item() # deg/s
# Convert mm/s to m/s for the kinematics calculations.
x_cmd = x_cmd_mm / 1000.0 # m/s
y_cmd = y_cmd_mm / 1000.0 # m/s
# Compute wheel commands from body commands.
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
arm_positions_list = arm_actions.tolist()
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
self.cmd_socket.send_string(json.dumps(message))
return action
def print_logs(self):
pass
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError("Not connected.")
if self.cmd_socket:
stop_cmd = {
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
"arm_positions": {},
}
self.cmd_socket.send_string(json.dumps(stop_cmd))
self.cmd_socket.close()
if self.video_socket:
self.video_socket.close()
if self.context:
self.context.term()
if PYNPUT_AVAILABLE:
self.listener.stop()
self.is_connected = False
print("[INFO] Disconnected from remote robot.")
def __del__(self):
if getattr(self, "is_connected", False):
self.disconnect()
if PYNPUT_AVAILABLE:
self.listener.stop()
@staticmethod
def degps_to_raw(degps: float) -> int:
steps_per_deg = 4096.0 / 360.0
speed_in_steps = abs(degps) * steps_per_deg
speed_int = int(round(speed_in_steps))
if speed_int > 0x7FFF:
speed_int = 0x7FFF
if degps < 0:
return speed_int | 0x8000
else:
return speed_int & 0x7FFF
@staticmethod
def raw_to_degps(raw_speed: int) -> float:
steps_per_deg = 4096.0 / 360.0
magnitude = raw_speed & 0x7FFF
degps = magnitude / steps_per_deg
if raw_speed & 0x8000:
degps = -degps
return degps
def body_to_wheel_raw(
self,
x_cmd: float,
y_cmd: float,
theta_cmd: float,
wheel_radius: float = 0.05,
base_radius: float = 0.125,
max_raw: int = 3000,
) -> dict:
"""
Convert desired body-frame velocities into wheel raw commands.
Parameters:
x_cmd : Linear velocity in x (m/s).
y_cmd : Linear velocity in y (m/s).
theta_cmd : Rotational velocity (deg/s).
wheel_radius: Radius of each wheel (meters).
base_radius : Distance from the center of rotation to each wheel (meters).
max_raw : Maximum allowed raw command (ticks) per wheel.
Returns:
A dictionary with wheel raw commands:
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
Notes:
- Internally, the method converts theta_cmd to rad/s for the kinematics.
- The raw command is computed from the wheels angular speed in deg/s
using degps_to_raw(). If any command exceeds max_raw, all commands
are scaled down proportionally.
"""
# Convert rotational velocity from deg/s to rad/s.
theta_rad = theta_cmd * (np.pi / 180.0)
# Create the body velocity vector [x, y, theta_rad].
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
# Build the kinematic matrix: each row maps body velocities to a wheels linear speed.
# The third column (base_radius) accounts for the effect of rotation.
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Compute each wheels linear speed (m/s) and then its angular speed (rad/s).
wheel_linear_speeds = m.dot(velocity_vector)
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
# Convert wheel angular speeds from rad/s to deg/s.
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
# Scaling
steps_per_deg = 4096.0 / 360.0
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
max_raw_computed = max(raw_floats)
if max_raw_computed > max_raw:
scale = max_raw / max_raw_computed
wheel_degps = wheel_degps * scale
# Convert each wheels angular speed (deg/s) to a raw integer.
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
def wheel_raw_to_body(
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
) -> tuple:
"""
Convert wheel raw command feedback back into body-frame velocities.
Parameters:
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
wheel_radius: Radius of each wheel (meters).
base_radius : Distance from the robot center to each wheel (meters).
Returns:
A tuple (x_cmd, y_cmd, theta_cmd) where:
x_cmd : Linear velocity in x (m/s).
y_cmd : Linear velocity in y (m/s).
theta_cmd : Rotational velocity in deg/s.
"""
# Extract the raw values in order.
raw_list = [
int(wheel_raw.get("left_wheel", 0)),
int(wheel_raw.get("back_wheel", 0)),
int(wheel_raw.get("right_wheel", 0)),
]
# Convert each raw command back to an angular speed in deg/s.
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
# Convert from deg/s to rad/s.
wheel_radps = wheel_degps * (np.pi / 180.0)
# Compute each wheels linear speed (m/s) from its angular speed.
wheel_linear_speeds = wheel_radps * wheel_radius
# Define the wheel mounting angles with a -90° offset.
angles = np.radians(np.array([240, 120, 0]) - 90)
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
m_inv = np.linalg.inv(m)
velocity_vector = m_inv.dot(wheel_linear_speeds)
x_cmd, y_cmd, theta_rad = velocity_vector
theta_cmd = theta_rad * (180.0 / np.pi)
return (x_cmd, y_cmd, theta_cmd)
class LeKiwi:
def __init__(self, motor_bus):
"""
Initializes the LeKiwi with Feetech motors bus.
"""
self.motor_bus = motor_bus
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
# Initialize motors in velocity mode.
self.motor_bus.write("Lock", 0)
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
self.motor_bus.write("Lock", 1)
print("Motors set to velocity mode.")
def read_velocity(self):
"""
Reads the raw speeds for all wheels. Returns a dictionary with motor names:
"""
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
return {
"left_wheel": int(raw_speeds[0]),
"back_wheel": int(raw_speeds[1]),
"right_wheel": int(raw_speeds[2]),
}
def set_velocity(self, command_speeds):
"""
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
The order of speeds must correspond to self.motor_ids.
"""
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
def stop(self):
"""Stops the robot by setting all motor speeds to zero."""
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
print("Motors stopped.")

View File

@@ -1,4 +0,0 @@
from .configuration_moss import MossRobotConfig
from .robot_moss import MossRobot
__all__ = ["MossRobotConfig", "MossRobot"]

View File

@@ -1,30 +0,0 @@
from dataclasses import dataclass, field
from lerobot.common.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("moss")
@dataclass
class MossRobotConfig(RobotConfig):
# Port to connect to the robot
port: str
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
mock: bool = False
# motors
shoulder_pan: tuple = (1, "sts3215")
shoulder_lift: tuple = (2, "sts3215")
elbow_flex: tuple = (3, "sts3215")
wrist_flex: tuple = (4, "sts3215")
wrist_roll: tuple = (5, "sts3215")
gripper: tuple = (6, "sts3215")
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -1,223 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import time
import numpy as np
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors import TorqueMode
from lerobot.common.motors.feetech import (
FeetechMotorsBus,
apply_feetech_offsets_from_calibration,
run_full_arm_calibration,
)
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .configuration_moss import MossRobotConfig
class MossRobot(Robot):
"""
[Moss Arm](https://github.com/jess-moss/moss-robot-arms) designed by Jess Moss
"""
config_class = MossRobotConfig
name = "moss"
def __init__(self, config: MossRobotConfig):
super().__init__(config)
self.config = config
self.robot_type = config.type
self.arm = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": config.shoulder_pan,
"shoulder_lift": config.shoulder_lift,
"elbow_flex": config.elbow_flex,
"wrist_flex": config.wrist_flex,
"wrist_roll": config.wrist_roll,
"gripper": config.gripper,
},
)
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False
self.logs = {}
@property
def state_feature(self) -> dict:
return {
"dtype": "float32",
"shape": (len(self.arm),),
"names": {"motors": list(self.arm.motors)},
}
@property
def action_feature(self) -> dict:
return self.state_feature
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
def connect(self) -> None:
if self.is_connected:
raise DeviceAlreadyConnectedError(
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
)
logging.info("Connecting arm.")
self.arm.connect()
# We assume that at connection time, arm is in a rest position,
# and torque can be safely disabled to run calibration.
self.arm.write("Torque_Enable", TorqueMode.DISABLED.value)
self.calibrate()
# Mode=0 for Position Control
self.arm.write("Mode", 0)
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
self.arm.write("P_Coefficient", 16)
# Set I_Coefficient and D_Coefficient to default value 0 and 32
self.arm.write("I_Coefficient", 0)
self.arm.write("D_Coefficient", 32)
# Close the write lock so that Maximum_Acceleration gets written to EPROM address,
# which is mandatory for Maximum_Acceleration to take effect after rebooting.
self.arm.write("Lock", 0)
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
# the motors. Note: this configuration is not in the official STS3215 Memory Table
self.arm.write("Maximum_Acceleration", 254)
self.arm.write("Acceleration", 254)
logging.info("Activating torque.")
self.arm.write("Torque_Enable", TorqueMode.ENABLED.value)
# Check arm can be read
self.arm.read("Present_Position")
# Connect the cameras
for cam in self.cameras.values():
cam.connect()
self.is_connected = True
def calibrate(self) -> None:
"""After calibration all motors function in human interpretable ranges.
Rotations are expressed in degrees in nominal range of [-180, 180],
and linear motions (like gripper of Aloha) in nominal range of [0, 100].
"""
if self.calibration_fpath.exists():
with open(self.calibration_fpath) as f:
calibration = json.load(f)
else:
# TODO(rcadene): display a warning in __init__ if calibration file not available
logging.info(f"Missing calibration file '{self.calibration_fpath}'")
calibration = run_full_arm_calibration(self.arm, self.robot_type, self.name, "follower")
logging.info(f"Calibration is done! Saving calibration file '{self.calibration_fpath}'")
self.calibration_fpath.parent.mkdir(parents=True, exist_ok=True)
with open(self.calibration_fpath, "w") as f:
json.dump(calibration, f)
self.arm.set_calibration(calibration)
apply_feetech_offsets_from_calibration(self.arm, calibration)
def get_observation(self) -> dict[str, np.ndarray]:
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise DeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
obs_dict = {}
# Read arm position
before_read_t = time.perf_counter()
obs_dict[OBS_STATE] = self.arm.read("Present_Position")
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
# Capture images from cameras
for cam_key, cam in self.cameras.items():
before_camread_t = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
return obs_dict
def send_action(self, action: np.ndarray) -> np.ndarray:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Args:
action (np.ndarray): array containing the goal positions for the motors.
Raises:
RobotDeviceNotConnectedError: if robot is not connected.
Returns:
np.ndarray: the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
)
goal_pos = action
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.arm.read("Present_Position")
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.arm.write("Goal_Position", goal_pos.astype(np.int32))
return goal_pos
def print_logs(self):
# TODO(aliberts): move robot-specific logs logic here
pass
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(
"ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
)
self.arm.disconnect()
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False

View File

@@ -1,95 +0,0 @@
import abc
from pathlib import Path
from typing import Any
import draccus
from lerobot.common.constants import HF_LEROBOT_CALIBRATION, ROBOTS
from lerobot.common.motors import MotorCalibration
from .config import RobotConfig
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
class Robot(abc.ABC):
"""The main LeRobot class for implementing robots."""
# Set these in ALL subclasses
config_class: RobotConfig
name: str
def __init__(self, config: RobotConfig):
self.robot_type = self.name
self.id = config.id
self.calibration_dir = (
config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name
)
self.calibration_dir.mkdir(parents=True, exist_ok=True)
self.calibration_fpath = self.calibration_dir / f"{self.id}.json"
self.calibration: dict[str, MotorCalibration] = {}
if self.calibration_fpath.is_file():
self._load_calibration()
def __str__(self) -> str:
return f"{self.id} {self.__class__.__name__}"
# TODO(aliberts): create a proper Feature class for this that links with datasets
@abc.abstractproperty
def state_feature(self) -> dict:
pass
@abc.abstractproperty
def action_feature(self) -> dict:
pass
@abc.abstractproperty
def camera_features(self) -> dict[str, dict]:
pass
@abc.abstractproperty
def is_connected(self) -> bool:
pass
@abc.abstractmethod
def connect(self) -> None:
"""Connects to the robot."""
pass
@abc.abstractproperty
def is_calibrated(self) -> bool:
pass
@abc.abstractmethod
def calibrate(self) -> None:
"""Calibrates the robot."""
pass
def _load_calibration(self, fpath: Path | None = None) -> None:
fpath = self.calibration_fpath if fpath is None else fpath
with open(fpath) as f, draccus.config_type("json"):
self.calibration = draccus.load(dict[str, MotorCalibration], f)
def _save_calibration(self, fpath: Path | None = None) -> None:
fpath = self.calibration_fpath if fpath is None else fpath
with open(fpath, "w") as f, draccus.config_type("json"):
draccus.dump(self.calibration, f, indent=4)
@abc.abstractmethod
def configure(self) -> None:
pass
@abc.abstractmethod
def get_observation(self) -> dict[str, Any]:
"""Gets observation from the robot."""
pass
@abc.abstractmethod
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Sends actions to the robot."""
pass
@abc.abstractmethod
def disconnect(self) -> None:
"""Disconnects from the robot."""
pass

View File

@@ -1,2 +0,0 @@
from .config_so100_follower import SO100FollowerConfig
from .so100_follower import SO100Follower

View File

@@ -1,22 +0,0 @@
from dataclasses import dataclass, field
from lerobot.common.cameras import CameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("so100_follower")
@dataclass
class SO100FollowerConfig(RobotConfig):
# Port to connect to the arm
port: str
disable_torque_on_disconnect: bool = True
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(default_factory=dict)

View File

@@ -1,215 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.common.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from ..robot import Robot
from ..utils import ensure_safe_goal_position
from .config_so100_follower import SO100FollowerConfig
logger = logging.getLogger(__name__)
class SO100Follower(Robot):
"""
[SO-100 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
"""
config_class = SO100FollowerConfig
name = "so100_follower"
def __init__(self, config: SO100FollowerConfig):
super().__init__(config)
self.config = config
self.arm = FeetechMotorsBus(
port=self.config.port,
motors={
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
"elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
"wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100),
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
},
calibration=self.calibration,
)
self.cameras = make_cameras_from_configs(config.cameras)
@property
def state_feature(self) -> dict:
return {
"dtype": "float32",
"shape": (len(self.arm),),
"names": {"motors": list(self.arm.motors)},
}
@property
def action_feature(self) -> dict:
return self.state_feature
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def is_connected(self) -> bool:
# TODO(aliberts): add cam.is_connected for cam in self.cameras
return self.arm.is_connected
def connect(self) -> None:
"""
We assume that at connection time, arm is in a rest position,
and torque can be safely disabled to run calibration.
"""
if self.is_connected:
raise DeviceAlreadyConnectedError(f"{self} already connected")
self.arm.connect()
if not self.is_calibrated:
self.calibrate()
# Connect the cameras
for cam in self.cameras.values():
cam.connect()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.arm.is_calibrated
def calibrate(self) -> None:
logger.info(f"\nRunning calibration of {self}")
self.arm.disable_torque()
for name in self.arm.names:
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
input("Move robot to the middle of its range of motion and press ENTER....")
homing_offsets = self.arm.set_half_turn_homings()
full_turn_motor = "wrist_roll"
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
logger.info(
f"Move all joints except '{full_turn_motor}' sequentially through their "
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
)
range_mins, range_maxes = self.arm.record_ranges_of_motion(unknown_range_motors)
range_mins[full_turn_motor] = 0
range_maxes[full_turn_motor] = 4095
self.calibration = {}
for name, motor in self.arm.motors.items():
self.calibration[name] = MotorCalibration(
id=motor.id,
drive_mode=0,
homing_offset=homing_offsets[name],
range_min=range_mins[name],
range_max=range_maxes[name],
)
self.arm.write_calibration(self.calibration)
self._save_calibration()
print("Calibration saved to", self.calibration_fpath)
def configure(self) -> None:
with self.arm.torque_disabled():
self.arm.configure_motors()
for name in self.arm.names:
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
self.arm.write("P_Coefficient", name, 16)
# Set I_Coefficient and D_Coefficient to default value 0 and 32
self.arm.write("I_Coefficient", name, 0)
self.arm.write("D_Coefficient", name, 32)
def get_observation(self) -> dict[str, Any]:
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
obs_dict = {}
# Read arm position
start = time.perf_counter()
obs_dict[OBS_STATE] = self.arm.sync_read("Present_Position")
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
# Capture images from cameras
for cam_key, cam in self.cameras.items():
start = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
return obs_dict
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Command arm to move to a target joint configuration.
The relative action magnitude may be clipped depending on the configuration parameter
`max_relative_target`. In this case, the action sent differs from original action.
Thus, this function always returns the action actually sent.
Raises:
RobotDeviceNotConnectedError: if robot is not connected.
Returns:
the action sent to the motors, potentially clipped.
"""
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
goal_pos = action
# Cap goal position when too far away from present position.
# /!\ Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.arm.sync_read("Present_Position")
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
# Send goal position to the arm
self.arm.sync_write("Goal_Position", goal_pos)
return goal_pos
def disconnect(self):
if not self.is_connected:
raise DeviceNotConnectedError(f"{self} is not connected.")
self.arm.disconnect(self.config.disable_torque_on_disconnect)
for cam in self.cameras.values():
cam.disconnect()
logger.info(f"{self} disconnected.")

View File

@@ -1,44 +0,0 @@
from dataclasses import dataclass, field
from lerobot.common.cameras import CameraConfig
from lerobot.common.cameras.intel import RealSenseCameraConfig
from lerobot.common.cameras.opencv import OpenCVCameraConfig
from ..config import RobotConfig
@RobotConfig.register_subclass("stretch3")
@dataclass
class Stretch3RobotConfig(RobotConfig):
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: int | None = None
# cameras
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
"navigation": OpenCVCameraConfig(
camera_index="/dev/hello-nav-head-camera",
fps=10,
width=1280,
height=720,
rotation=-90,
),
"head": RealSenseCameraConfig(
name="Intel RealSense D435I",
fps=30,
width=640,
height=480,
rotation=90,
),
"wrist": RealSenseCameraConfig(
name="Intel RealSense D405",
fps=30,
width=640,
height=480,
),
}
)
mock: bool = False

View File

@@ -1,183 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import numpy as np
from stretch_body.gamepad_teleop import GamePadTeleop
from stretch_body.robot import Robot as StretchAPI
from stretch_body.robot_params import RobotParams
from lerobot.common.cameras.utils import make_cameras_from_configs
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
from lerobot.common.datasets.utils import get_nested_item
from ..robot import Robot
from .configuration_stretch3 import Stretch3RobotConfig
# {lerobot_keys: stretch.api.keys}
STRETCH_MOTORS = {
"head_pan.pos": "head.head_pan.pos",
"head_tilt.pos": "head.head_tilt.pos",
"lift.pos": "lift.pos",
"arm.pos": "arm.pos",
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
"gripper.pos": "end_of_arm.stretch_gripper.pos",
"base_x.vel": "base.x_vel",
"base_y.vel": "base.y_vel",
"base_theta.vel": "base.theta_vel",
}
class Stretch3Robot(Robot):
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
config_class = Stretch3RobotConfig
name = "stretch3"
def __init__(self, config: Stretch3RobotConfig):
super().__init__(config)
self.config = config
self.robot_type = self.config.type
self.api = StretchAPI()
self.cameras = make_cameras_from_configs(config.cameras)
self.is_connected = False
self.logs = {}
self.teleop = None # TODO remove
# TODO(aliberts): test this
RobotParams.set_logging_level("WARNING")
RobotParams.set_logging_formatter("brief_console_formatter")
self.state_keys = None
self.action_keys = None
@property
def state_feature(self) -> dict:
return {
"dtype": "float32",
"shape": (len(STRETCH_MOTORS),),
"names": {"motors": list(STRETCH_MOTORS)},
}
@property
def action_feature(self) -> dict:
return self.state_feature
@property
def camera_features(self) -> dict[str, dict]:
cam_ft = {}
for cam_key, cam in self.cameras.items():
cam_ft[cam_key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
def connect(self) -> None:
self.is_connected = self.api.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
raise ConnectionError()
for cam in self.cameras.values():
cam.connect()
self.is_connected = self.is_connected and cam.is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.calibrate()
def calibrate(self) -> None:
if not self.api.is_homed():
self.api.home()
def _get_state(self) -> dict:
status = self.api.get_status()
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
def get_observation(self) -> dict[str, np.ndarray]:
obs_dict = {}
# Read Stretch state
before_read_t = time.perf_counter()
state = self._get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
state = np.asarray(list(state.values()))
obs_dict[OBS_STATE] = state
# Capture images from cameras
for cam_key, cam in self.cameras.items():
before_camread_t = time.perf_counter()
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
return obs_dict
def send_action(self, action: np.ndarray) -> np.ndarray:
if not self.is_connected:
raise ConnectionError()
if self.teleop is None:
self.teleop = GamePadTeleop(robot_instance=False)
self.teleop.startup(robot=self)
if self.action_keys is None:
dummy_action = self.teleop.gamepad_controller.get_state()
self.action_keys = list(dummy_action.keys())
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
before_write_t = time.perf_counter()
self.teleop.do_motion(state=action_dict, robot=self)
self.push_command()
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# TODO(aliberts): return action_sent when motion is limited
return action
def print_logs(self) -> None:
pass
# TODO(aliberts): move robot-specific logs logic here
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
def disconnect(self) -> None:
self.api.stop()
if self.teleop is not None:
self.teleop.gamepad_controller.stop()
self.teleop.stop()
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False

Some files were not shown because too many files have changed in this diff Show More