Compare commits

112 Commits

Author SHA1 Message Date
AdilZouitine
36714a14a7 Update tensor device assignment in ReplayBuffer class
- Changed the device assignment for tensors in the ReplayBuffer class from `device` to `storage_device` for consistency and improved resource management.
2025-03-21 14:21:58 +00:00
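A minimal sketch of the pattern this commit points at, with illustrative names: the buffer allocates its tensors on `storage_device` instead of the compute `device`.

```python
import torch

class ReplayBuffer:
    def __init__(self, capacity: int, state_dim: int,
                 device: str = "cpu", storage_device: str = "cpu"):
        self.device = device                  # compute device used for training
        self.storage_device = storage_device  # where stored tensors live
        # The change: allocate on storage_device rather than device.
        self.states = torch.empty(capacity, state_dim, device=self.storage_device)
```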
pre-commit-ci[bot]
68b8e274dd [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-20 12:58:44 +00:00
AdilZouitine
1a7b4ec890 Initialize log_alpha with the logarithm of temperature_init in SACPolicy
- Updated the SACPolicy class to set log_alpha using the logarithm of the initial temperature value from the configuration.
2025-03-20 12:57:34 +00:00
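A short sketch of the initialization described above; `temperature_init` stands in for the config value.

```python
import math

import torch
from torch import nn

temperature_init = 1e-2  # placeholder for the config value
# Initializing log_alpha at log(temperature_init) makes
# alpha = exp(log_alpha) start exactly at temperature_init.
log_alpha = nn.Parameter(torch.tensor(math.log(temperature_init)))
print(log_alpha.exp())  # tensor(0.0100, ...)
```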
pre-commit-ci[bot]
1c9eccd279 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-19 18:53:27 +00:00
AdilZouitine
7551260104 Remove unused functions and imports from modeling_sac.py
- Deleted the `find_and_copy_params` function and the `Ensemble` class, as they were deemed unnecessary.
- Cleaned up imports by removing `from_modules` from `tensordict` to enhance code clarity.
- Simplified the assertion in the `Policy` class for better readability.
2025-03-19 18:53:01 +00:00
AdilZouitine
95758cb867 Add intervention rate tracking in act_with_policy function
- Introduced counters for tracking intervention steps and total steps during training.
- Calculated and logged the intervention rate at the end of each episode.
- Reset intervention counters after each episode to ensure accurate tracking.
2025-03-19 18:37:50 +00:00
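A hedged sketch of the bookkeeping this commit describes; the loop shape and names are assumptions.

```python
def episode_intervention_rate(interventions: list[bool]) -> float:
    """interventions[i] is True when a human intervened at step i."""
    intervention_steps = sum(interventions)
    total_steps = len(interventions)
    # Logged at the end of each episode; counters are reset afterwards.
    return intervention_steps / max(total_steps, 1)

print(episode_intervention_rate([False, True, True, False]))  # 0.5
```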
AdilZouitine
2ecc34ceb9 - Updated the logging condition to use log_freq directly instead of accessing it through cfg.training.log_freq for improved readability and speed. 2025-03-19 13:40:23 +00:00
Eugene Mironov
8598e80718 [PORT HIL-SERL] Optimize training loop, extract config usage (#855)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-03-19 14:27:32 +01:00
AdilZouitine
6fa3e5f9ad Enhance training information logging in learner server
- Added tracking for replay buffer size and offline replay buffer size during training steps.
2025-03-19 13:16:31 +00:00
AdilZouitine
b7bd13570f Update configuration files for improved performance and flexibility
- Increased frame rate in `maniskill_example.yaml` from 20 to 400 for enhanced simulation speed.
- Updated `sac_maniskill.yaml` to set `dataset_repo_id` to null and adjusted `grad_clip_norm` from 10.0 to 40.0.
- Changed `storage_device` from "cpu" to "cuda" for better resource utilization.
- Modified `save_freq` from 2000000 to 1000000 to optimize saving intervals.
- Enhanced input normalization parameters for `observation.state` and `observation.image` in SAC policy.
- Adjusted `num_critics` from 10 to 2 and `policy_parameters_push_frequency` from 1 to 4 for improved training dynamics.
- Updated `learner_server.py` to utilize `offline_buffer_capacity` for replay buffer initialization.
- Changed action multiplier in `maniskill_manipulator.py` from 1 to 0.03 for finer control over actions.
2025-03-19 09:56:02 +00:00
pre-commit-ci[bot]
f899edb57f [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-18 14:57:58 +00:00
AdilZouitine
17ec837a7a Refactor SACPolicy and learner server for improved replay buffer management
- Updated SACPolicy to create critic heads using a list comprehension for better readability.
- Simplified the saving and loading of models using `save_model` and `load_model` functions from the safetensors library.
- Introduced `initialize_offline_replay_buffer` function in the learner server to streamline offline dataset handling and replay buffer initialization.
- Enhanced logging for dataset loading processes to improve traceability during training.
2025-03-18 14:57:15 +00:00
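A small sketch of the critic-head construction mentioned above; the sizes are illustrative.

```python
from torch import nn

num_critics, feature_dim = 2, 256  # illustrative sizes
critic_heads = nn.ModuleList(
    [nn.Linear(feature_dim, 1) for _ in range(num_critics)]
)
```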
Michel Aractingi
9e3c8461ca Add end effector action space to hil-serl (#861)
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-03-17 14:22:33 +01:00
AdilZouitine
1f23ef7889 Enhance SAC configuration and policy with gradient clipping and temperature management
- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
2025-03-17 11:59:21 +00:00
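A minimal sketch of the clipping pattern in a training step, assuming a `grad_clip_norm` value like the one in the SAC config.

```python
import torch
from torch import nn

model = nn.Linear(8, 2)  # stand-in for the actor or critic
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
grad_clip_norm = 40.0    # assumed to come from the SAC config

loss = model(torch.randn(4, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
# clip_grad_norm_ returns the pre-clipping norm, which can be logged.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
optimizer.step()
print(f"grad norm: {float(grad_norm):.3f}")
```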
pre-commit-ci[bot]
41219fe81e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-12 10:16:55 +00:00
AdilZouitine
5081c145dc Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving
- Add `_from_pretrained` class method for loading SAC policy from files
- Create utility function `find_and_copy_params` to handle parameter copying
2025-03-12 10:15:37 +00:00
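A sketch of saving and loading with the safetensors helpers named in the commit; the policy class here is a stand-in.

```python
from torch import nn
from safetensors.torch import load_model, save_model

class TinyPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.actor = nn.Linear(4, 2)

policy = TinyPolicy()
save_model(policy, "sac_policy.safetensors")  # writes the state dict to disk
load_model(policy, "sac_policy.safetensors")  # restores it in place
```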
AdilZouitine
25b88f3b86 Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method

- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
2025-03-12 09:46:47 +00:00
s1lent4gnt
d711e20b5f [Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-03-12 10:35:30 +01:00
Eugene Mironov
700f00c014 [HIL-SERL] Migrate threading to multiprocessing (#759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-03-05 11:19:31 +01:00
pre-commit-ci[bot]
584cad808e [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-03-04 13:38:48 +00:00
AdilZouitine
d8a1758122 Add storage device configuration for SAC policy and replay buffer
- Introduce `storage_device` parameter in SAC configuration and training settings
- Update learner server to use configurable storage device for replay buffer
- Reduce online buffer capacity in ManiSkill configuration
- Modify replay buffer initialization to support custom storage device
2025-03-04 13:22:35 +00:00
AdilZouitine
1df9ee4f2d Add memory optimization option to ReplayBuffer
- Introduce `optimize_memory` parameter to reduce memory usage in replay buffer
- Implement simplified memory optimization by not storing duplicate next_states
- Update learner server and buffer initialization to use memory optimization by default
2025-02-25 19:04:58 +00:00
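A sketch of the `optimize_memory` idea as described: instead of storing `next_states` twice, read the state stored at the following index (episode-boundary handling omitted).

```python
import torch

capacity, state_dim = 1000, 8
states = torch.empty(capacity, state_dim)

def next_states(idx: torch.Tensor) -> torch.Tensor:
    # The transition at index i reuses states[i + 1] as its next state,
    # so next states are never duplicated in storage.
    return states[(idx + 1) % capacity]
```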
AdilZouitine
5b4a7aa81d Add storage device parameter to replay buffer initialization
- Specify storage device for replay buffer to optimize memory management
2025-02-25 15:30:39 +00:00
AdilZouitine
ef8d943e54 Refactor ReplayBuffer with tensor-based storage and improved sampling efficiency
- Replaced list-based memory storage with pre-allocated tensor storage
- Optimized sampling process with direct tensor indexing
- Added support for DrQ image augmentation during sampling for offline dataset
- Improved dataset conversion with more robust episode handling
- Enhanced buffer initialization and state tracking
- Added comprehensive testing for buffer conversion and sampling
2025-02-25 14:26:44 +00:00
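A compact sketch of tensor-based storage with direct-index sampling, under assumed shapes:

```python
import torch

class TensorReplayBuffer:
    """List-free storage: everything pre-allocated, sampling by index."""

    def __init__(self, capacity: int, state_dim: int, action_dim: int):
        self.capacity = capacity
        self.size = 0
        self.pos = 0
        self.states = torch.empty(capacity, state_dim)
        self.actions = torch.empty(capacity, action_dim)
        self.rewards = torch.empty(capacity)

    def add(self, state, action, reward):
        self.states[self.pos] = state
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int):
        idx = torch.randint(0, self.size, (batch_size,))
        # Direct tensor indexing instead of Python-list gathers.
        return self.states[idx], self.actions[idx], self.rewards[idx]
```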
AdilZouitine
42a038173f Update ManiSkill configuration and replay buffer to support truncation and dataset handling
- Reduced image size in ManiSkill environment configuration from 128 to 64
- Added support for truncation in replay buffer and actor server
- Updated SAC policy configuration to use a specific dataset and modify vision encoder settings
- Improved dataset conversion process with progress tracking and task naming
- Added flexibility for joint action space masking in learner server
2025-02-24 16:53:37 +00:00
Michel Aractingi
546719137a Added a caching function in the learner_server and modeling_sac in order to limit the number of forward passes through the pretrained encoder when it's frozen.
Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-21 10:13:43 +00:00
Eugene Mironov
3ffe0cf0f4 [Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722) 2025-02-21 10:29:00 +01:00
AdilZouitine
ff82367c62 Refactor SAC policy with performance optimizations and multi-camera support
- Introduced Ensemble and CriticHead classes for more efficient critic network handling
- Added support for multiple camera inputs in observation encoder
- Optimized image encoding by batching image processing
- Updated configuration for ManiSkill environment with reduced image size and action scaling
- Compiled critic networks for improved performance
- Simplified normalization and ensemble handling in critic networks
Co-authored-by: michel-aractingi <michel.aractingi@gmail.com>
2025-02-20 17:14:27 +00:00
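One way to read "batching image processing" across cameras, as a hedged sketch: stack the per-camera batches, encode once, split the features back.

```python
import torch
from torch import nn

encoder = nn.Conv2d(3, 16, kernel_size=3)  # stand-in for the vision encoder

def encode_cameras(images_per_camera: list[torch.Tensor]) -> list[torch.Tensor]:
    batch_size = images_per_camera[0].shape[0]
    stacked = torch.cat(images_per_camera, dim=0)   # (N*B, C, H, W)
    features = encoder(stacked)                     # single forward pass
    return list(features.split(batch_size, dim=0))  # back to per-camera features

feats = encode_cameras([torch.randn(4, 3, 64, 64), torch.randn(4, 3, 64, 64)])
print(len(feats), feats[0].shape)
```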
Michel Aractingi
ff47c0b0d3 - Fixed a major issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to update_policy_parameters and remove strict=False
- Fixed a major issue in the normalization of the actions in the `forward` function of the critic -- removed the `torch.no_grad` decorator from the normalization function in `normalize.py`
- Fixed a performance issue to boost the optimization frequency by setting the storage device to be the same as the learning device.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-19 16:22:51 +00:00
AdilZouitine
befa1fe9af Re-enable parameter push thread in learner server
- Uncomment and start the param_push_thread
- Restore thread joining for param_push_thread
2025-02-17 10:26:33 +00:00
AdilZouitine
446f434a8e Improve wandb logging and custom step tracking in logger
- Modify logger to support multiple custom step keys
- Update logging method to handle custom step keys more flexibly

- Enhance logging of optimization step and frequency
Co-authored-by: michel-aractingi  <michel.aractingi@gmail.com>
2025-02-17 10:08:49 +00:00
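A plausible sketch using wandb's `define_metric` to bind metric families to custom step keys; the key names here are assumptions.

```python
import wandb

run = wandb.init(project="hil-serl", mode="offline")
wandb.define_metric("optimization_step")
wandb.define_metric("interaction_step")
wandb.define_metric("train/*", step_metric="optimization_step")
wandb.define_metric("actor/*", step_metric="interaction_step")

# Each log call carries its own step key.
wandb.log({"train/critic_loss": 0.3, "optimization_step": 100})
wandb.log({"actor/reward": 1.0, "interaction_step": 2000})
```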
AdilZouitine
2f3370e42f Add maniskill support.
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
2025-02-14 19:53:29 +00:00
Michel Aractingi
7ae368e983 Fixed bug in the action scale of the intervention actions and offline dataset actions. (scale by inverse delta)
Co-authored-by: Adil Zouitine <adizouitinegm@gmail.com>
2025-02-14 15:17:16 +01:00
Michel Aractingi
36711d766a Modified the crop_dataset_roi interface to automatically write the crop parameters to a json file in the meta of the dataset
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-14 12:32:45 +01:00
Michel Aractingi
c9e50bb9b1 Optimized the replay buffer memory usage by storing data on the CPU instead of a GPU device and sending the batches to the GPU.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 18:03:57 +01:00
Michel Aractingi
95de8e273d nit
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 17:12:57 +01:00
Michel Aractingi
b07d95f0dd removed uncomment in actor server
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 16:53:33 +01:00
Michel Aractingi
d9a70376d8 Changed the init_final value to center the starting mean and std of the policy
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 16:42:43 +01:00
Michel Aractingi
0c32008466 Changed bounds for a new so100 robot
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 15:43:30 +01:00
Michel Aractingi
c462a478c7 Hardcoded some normalization parameters. TODO: refactor
Added masking of actions at the level of the intervention actions and the offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 14:27:14 +01:00
Michel Aractingi
459f22ed30 fix log_alpha in modeling_sac: change to nn.Parameter
added pretrained vision model in policy

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 11:26:24 +01:00
Michel Aractingi
dc086dc21f Added logging for interventions to monitor the rate of interventions over time
Added an 's' keyboard command to force success in case the reward classifier fails

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-13 11:04:49 +01:00
Michel Aractingi
b9217b06db Added the possibility to record and replay delta actions during teleoperation rather than absolute actions
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-12 19:25:41 +01:00
Yoel
6868c88ef1 [PORT-Hilserl] classifier fixes (#695)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-11 11:39:17 +01:00
Eugene Mironov
a1d16fb400 [Port HIL-SERL] Add resnet-10 as default encoder for HIL-SERL (#696)
Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
2025-02-11 11:37:00 +01:00
Michel Aractingi
a7db3959f5 - Added JointMaskingActionSpace wrapper in gym_manipulator in order to select which joints will be controlled. For example, we can disable the gripper actions for some tasks.
- Added NaN detection mechanisms in the actor, learner and gym_manipulator for the case where we encounter NaNs in the loop.
- Changed `non_blocking` in the `.to(device)` calls to apply only on CUDA, because it was causing NaNs when running the policy on MPS.
- Added some joint clipping and limits in the env, robot and policy configs. TODO: clean this part and keep the limits in one config file only.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-11 11:34:46 +01:00
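The non-blocking fix, sketched: enable `non_blocking` transfers only when the target device is CUDA.

```python
import torch

def to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict:
    # non_blocking transfers caused NaNs on MPS, so gate them on CUDA.
    non_blocking = device.type == "cuda"
    return {k: v.to(device, non_blocking=non_blocking) for k, v in batch.items()}

batch = to_device({"observation.state": torch.randn(2, 6)}, torch.device("cpu"))
```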
Michel Aractingi
b5f89439ff Added sac_real config file in the policy configs dir.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-10 16:08:13 +01:00
Michel Aractingi
d51374ce12 Several fixes to move the actor_server and learner_server code from the maniskill environment to the real robot environment.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-10 16:03:39 +01:00
Eugene Mironov
b63738674c [HIL-SERL port] Add Reward classifier benchmark tracking to chose best visual encoder (#688) 2025-02-06 18:39:51 +01:00
Michel Aractingi
12525242ce - Added lerobot/scripts/server/gym_manipulator.py that contains all the necessary wrappers to run a gym-style env around the real robot.
- Added `lerobot/scripts/server/find_joint_limits.py` to test the min and max angles of the motion you wish the robot to explore during RL training.
- Added logic in `manipulator.py` to limit the maximum possible joint angles to allow motion within a predefined joint position range. The limits are specified in the yaml config for each robot. Check out so100.yaml.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-06 16:29:37 +01:00
Michel Aractingi
7d5a9530f7 fixed bug in crop_dataset_roi.py
added missing buffer.pt in server dir

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-05 18:22:50 +00:00
Michel Aractingi
e0527b4a6b Added additional wrappers for the environment: action repeat, keyboard interface, reset wrapper
Tested the reset mechanism, keyboard interface, and convert wrapper on the robots.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-04 17:41:14 +00:00
Michel Aractingi
efb1982eec Added crop_dataset_roi.py that allows you to load a LeRobotDataset -> crop its images -> create a new LeRobot dataset with the cropped and resized images.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 17:48:35 +00:00
Michel Aractingi
2211209be5 - Added base gym env class for the real robot environment.
- Added several wrappers around the base gym env robot class.
- Including: time limit, reward classifier, crop images, preprocess observations.
- Added an interactive script crop_roi.py where the user can interactively select the ROI in the observation images and return the correct crop values that will improve the policy and reward classifier performance.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:59 +00:00
Michel Aractingi
506821c7df - Refactor observation encoder in modeling_sac.py
- added `torch.compile` to the actor and learner servers.
- organized imports in `train_sac.py`
- optimized the parameter push by not sending the frozen pre-trained encoder.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
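A sketch of excluding the frozen encoder from the parameter push: send only state-dict entries whose parameters require gradients (names illustrative).

```python
from torch import nn

def trainable_state_dict(model: nn.Module) -> dict:
    trainable = {n for n, p in model.named_parameters() if p.requires_grad}
    return {k: v for k, v in model.state_dict().items() if k in trainable}

encoder = nn.Linear(4, 4)
encoder.requires_grad_(False)  # frozen pretrained encoder
model = nn.Sequential(encoder, nn.Linear(4, 2))
print(trainable_state_dict(model).keys())  # only the trainable layer's params
```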
Yoel
f1c8bfe01e [Port HIL-SERL] Add HF vision encoder option in SAC (#651)
Added support for a custom pretrained vision encoder in the modeling_sac implementation. Great job @ChorntonYoel!
2025-02-03 15:07:58 +00:00
Michel Aractingi
7c89bd1018 Cleaned learner_server.py. Added several block functions to improve readability.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
Michel Aractingi
367dfe51c6 Added support for checkpointing the policy. We can save and load the policy state dict, optimizer states, optimization step and interaction step.
Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
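A hedged sketch of the checkpoint layout described above; the exact field names in the repo may differ.

```python
import torch

def save_checkpoint(path, policy, optimizers, optimization_step, interaction_step):
    torch.save(
        {
            "policy": policy.state_dict(),
            "optimizers": {n: opt.state_dict() for n, opt in optimizers.items()},
            "optimization_step": optimization_step,
            "interaction_step": interaction_step,
        },
        path,
    )

def load_checkpoint(path, policy, optimizers):
    ckpt = torch.load(path)
    policy.load_state_dict(ckpt["policy"])
    for name, opt in optimizers.items():
        opt.load_state_dict(ckpt["optimizers"][name])
    return ckpt["optimization_step"], ckpt["interaction_step"]
```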
Michel Aractingi
e856ffc91e Removed unnecessary time.sleep in the streaming server on the learner side
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
Michel Aractingi
9aabe212ea Added missing config files env/maniskill_example.yaml and policy/sac_maniskill.yaml that are necessary to run the lerobot implementation of sac with the maniskill baselines.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
Michel Aractingi
42618f4bd6 - Added additional logging information in wandb around the timings of the policy loop and optimization loop.
- Optimized critic design that improves the performance of the learner loop by a factor of 2
- Cleaned the code and fixed style issues

- Completed the config with an actor_learner_config field that contains the host IP and port elements that are necessary for the actor-learner servers.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
Michel Aractingi
36576c958f FREEDOM, added back the optimization loop code in learner_server.py
Ran an experiment with the pushcube env from maniskill. The learning seems to work.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
Michel Aractingi
322a78a378 Added server directory in lerobot/scripts that contains scripts and the protobuf message types to split training into two processes, acting and learning. The actor rolls out the policy and collects interaction data while the learner receives the data, trains the policy and sends the updated parameters to the actor. The two scripts are run simultaneously.
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-02-03 15:07:58 +00:00
AdilZouitine
d75b44f89f Stable version of rlpd + drq 2025-02-03 15:07:57 +00:00
AdilZouitine
1fb03d4cf2 Add type annotations and restructure SACConfig class fields 2025-02-03 15:07:57 +00:00
Adil Zouitine
7d2970fdfe Change SAC policy implementation with configuration and modeling classes 2025-02-03 15:07:50 +00:00
Adil Zouitine
8105efb338 Add rlpd tricks 2025-02-03 15:06:18 +00:00
Adil Zouitine
c1d4bf4b63 SAC works 2025-02-03 15:06:18 +00:00
Adil Zouitine
86df8a433d remove breakpoint 2025-02-03 15:06:18 +00:00
Adil Zouitine
956c547254 [WIP] correct sac implementation 2025-02-03 15:06:18 +00:00
Adil Zouitine
be965019bd Add rlpd tricks 2025-02-03 15:06:18 +00:00
Adil Zouitine
a0a50de8c9 SAC works 2025-02-03 15:06:18 +00:00
Adil Zouitine
c86dace4c2 remove breakpoint 2025-02-03 15:06:18 +00:00
Adil Zouitine
472a7f58ad [WIP] correct sac implementation 2025-02-03 15:06:14 +00:00
Pradeep Kadubandi
068efce3f8 Fix for the issue https://github.com/huggingface/lerobot/issues/638 (#639) 2025-02-03 15:04:03 +00:00
Philip Fung
df7310ea40 fixes to SO-100 readme (#600)
Co-authored-by: Philip Fung <no@one>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-03 15:04:03 +00:00
Mishig
100f54ee07 [viz] Fixes & updates to html visualizer (#617) 2025-02-03 15:04:03 +00:00
CharlesCNorton
c2f7af3339 typo fix: batch_convert_dataset_v1_to_v2.py (#615)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-03 15:04:03 +00:00
Ville Kuosmanen
a1b5d0faf2 fix(visualise): use correct language description for each episode id (#604)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-03 15:04:03 +00:00
CharlesCNorton
d6498150bf fix(docs): typos in benchmark readme.md (#614)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2025-02-03 15:04:03 +00:00
Simon Alibert
31c34a4a49 Fix Quality workflow (#622) 2025-02-03 15:04:03 +00:00
CharlesCNorton
b1cfb6a710 Update README.md (#612) 2025-02-03 15:04:02 +00:00
Eugene Mironov
4a43c83522 Fix broken create_lerobot_dataset_card (#590) 2025-02-03 15:04:02 +00:00
Mishig
0a4e9e25d0 [vizualizer] for LeRobodDataset V2 (#576) 2025-02-03 15:04:02 +00:00
Michel Aractingi
3bb5ed5e91 Extend reward classifier for multiple camera views (#626) 2025-01-13 13:57:49 +01:00
Eugene Mironov
c5bca1cf0f [Port HIL_SERL] Final fixes for the Reward Classifier (#598) 2025-01-06 11:34:00 +01:00
Michel Aractingi
35de91ef2b added temporary fix for missing task_index key in online environment 2024-12-30 13:47:28 +00:00
Michel Aractingi
ee306e2f9b split encoder for critic and actor 2024-12-29 23:59:39 +00:00
Michel Aractingi
bae3b02928 style fixes 2024-12-29 14:35:21 +00:00
KeWang1017
5b4adc00bb Refactor SAC configuration and policy for improved action sampling and stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2024-12-29 14:27:19 +00:00
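A sketch of the clamped-log-std, tanh-squashed sampling this parameterization implies; the bounds are illustrative.

```python
import torch

log_std_min, log_std_max = -5.0, 2.0  # illustrative bounds

def sample_action(mean: torch.Tensor, log_std: torch.Tensor):
    log_std = torch.clamp(log_std, log_std_min, log_std_max)
    dist = torch.distributions.Normal(mean, log_std.exp())
    u = dist.rsample()      # reparameterized sample
    action = torch.tanh(u)  # squash into [-1, 1]
    # Change-of-variables correction for the tanh squashing.
    log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + 1e-6)
    return action, log_prob.sum(-1)

action, log_prob = sample_action(torch.zeros(4, 2), torch.zeros(4, 2))
```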
KeWang1017
22fbc9ea4a Refine SAC configuration and policy for enhanced performance
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
2024-12-29 14:21:49 +00:00
KeWang1017
ca74a13d61 Refactor SACPolicy for improved action sampling and standard deviation handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2024-12-29 14:17:25 +00:00
KeWang1017
18a4598986 trying to get sac running 2024-12-29 14:14:13 +00:00
Michel Aractingi
dc54d357ca Added normalization schemes and style checks 2024-12-29 12:51:21 +00:00
Michel Aractingi
08ec971086 added optimizer and sac to factory.py 2024-12-23 14:12:03 +01:00
Eugene Mironov
b53d6e0ff2 [HIL-SERL PORT] Fix linter issues (#588) 2024-12-23 10:44:29 +01:00
Eugene Mironov
70b652f791 [Port Hil-SERL] Add unit tests for the reward classifier & fix imports & check script (#578) 2024-12-23 10:43:55 +01:00
Michel Aractingi
7b68bfb73b added comments from kewang 2024-12-17 18:03:46 +01:00
KeWang1017
7e0f20fbf2 Enhance SAC configuration and policy with new parameters and subsampling logic
- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
2024-12-17 17:58:11 +01:00
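Two of the mechanisms above, sketched with assumed shapes: the default target entropy and REDQ-style critic subsampling.

```python
import torch

action_dim, num_critics, num_subsample_critics = 6, 10, 2

# Default target entropy when none is provided: -dim(action space).
target_entropy = -float(action_dim)

# Subsample critics before taking the min to reduce overfitting.
q_values = torch.randn(num_critics, 256)  # one row per critic head
idx = torch.randperm(num_critics)[:num_subsample_critics]
min_q = q_values[idx].min(dim=0).values   # (256,) target Q estimate
```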
KeWang
def42ff487 Port SAC WIP (#581)
Co-authored-by: KeWang1017 <ke.wang@helloleap.ai>
2024-12-17 16:16:59 +01:00
Michel Aractingi
c9af8e36a7 completed losses 2024-12-17 16:16:36 +01:00
Michel Aractingi
ed66c92383 nit in control_robot.py 2024-12-17 11:04:56 +07:00
Michel Aractingi
668d493bf9 Update lerobot/scripts/train_hilserl_classifier.py
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2024-12-17 02:44:31 +07:00
Claudio Coppola
67f4d7ea7a LerobotDataset pushable to HF from any folder (#563) 2024-12-17 02:44:23 +07:00
berjaoui
4b0c88ff8e Update 7_get_started_with_real_robot.md (#559) 2024-12-17 02:44:11 +07:00
Michel Aractingi
b19fef9d18 Control simulated robot with real leader (#514)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2024-12-17 02:44:03 +07:00
Remi
1612e00e63 Fix missing local_files_only in record/replay (#540)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
2024-12-17 02:43:10 +07:00
Michel Aractingi
c3bc136420 Refactor OpenX (#505) 2024-12-17 02:42:59 +07:00
Eugene Mironov
1020bc3108 Fixup 2024-12-17 02:42:53 +07:00
Michel Aractingi
7fcf638c0d Add human intervention mechanism and eval_robot script to evaluate policy on the robot (#541)
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2024-12-17 02:41:31 +07:00
Yoel
e35546f58e Reward classifier and training (#528)
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2024-12-17 02:41:29 +07:00
Michel Aractingi
1aa8d4ac91 nit 2024-12-17 02:39:15 +07:00
141 changed files with 16685 additions and 1240 deletions

View File

@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
run: pipx install "poetry<2.0.0"
- name: Poetry check
run: poetry check
@@ -64,7 +64,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
run: pipx install "poetry<2.0.0"
- name: Install poetry-relax
run: poetry self add poetry-relax

View File

@@ -17,6 +17,7 @@ repos:
rev: v3.19.0
hooks:
- id: pyupgrade
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
hooks:

View File

@@ -68,7 +68,7 @@
### Acknowledgment
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.

View File

@@ -21,7 +21,7 @@ How to decode videos?
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`

View File

@@ -32,7 +32,11 @@ import numpy as np
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
return total_size
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
def load_original_frames(
imgs_dir: Path, timestamps: list[float], fps: int
) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
@@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@@ -104,7 +114,10 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
img_keys = [
key for key in hf_dataset.features if key.startswith("observation.image")
]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
break
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
def sample_timestamps(
timestamps_mode: str, ep_num_images: int, fps: int
) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
@@ -154,7 +175,9 @@ def decode_video_frames(
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise NotImplementedError(backend)
@@ -181,7 +204,9 @@ def benchmark_decoding(
}
with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
frames = decode_video_frames(
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
@@ -190,12 +215,18 @@ def benchmark_decoding(
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["mse_values"].append(
mean_squared_error(original_frames_np[i], frames_np[i])
)
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
peak_signal_noise_ratio(
original_frames_np[i], frames_np[i], data_range=1.0
)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
structural_similarity(
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
)
)
if save_frames and sample == 0:
@@ -215,7 +246,9 @@ def benchmark_decoding(
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
for future in tqdm(
as_completed(futures), total=num_samples, desc="samples", leave=False
):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
@@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
for backend in tqdm(
decoding_cfg["backends"], desc="decodings (backends)", leave=False
):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
@@ -355,14 +392,23 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for key, values in tqdm(
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
args_path = Path(
"_".join(str(value) for value in encoding_cfg.values())
)
video_path = (
output_dir
/ "videos"
/ args_path
/ f"{repo_id.replace('/', '_')}.mp4"
)
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
@@ -388,7 +434,9 @@ def main(
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_path = (
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
)
concatenated_df.to_csv(concatenated_path, header=True, index=False)

checkport.py (new file, 18 lines)
View File

@@ -0,0 +1,18 @@
import socket
def check_port(host, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
s.connect((host, port))
print(f"Connection successful to {host}:{port}!")
except Exception as e:
print(f"Connection failed to {host}:{port}: {e}")
finally:
s.close()
if __name__ == "__main__":
host = "127.0.0.1" # or "localhost"
port = 51350
check_port(host, port)

View File

@@ -0,0 +1,11 @@
FROM huggingface/lerobot-gpu:latest
RUN apt-get update && apt-get install -y --no-install-recommends \
libvulkan1 vulkan-tools \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[mani-skill]"
# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"

View File

@@ -1,25 +1,31 @@
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
## Source the parts
## A. Source the parts
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
## Install LeRobot
## B. Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
# Linux:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
# Mac M-series:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
# Mac Intel:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc`
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
3. Create and activate a fresh conda environment for lerobot
```bash
@@ -36,23 +42,30 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
cd ~/lerobot && pip install -e ".[feetech]"
```
For Linux only (not Mac), install extra dependencies for recording datasets:
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## Configure the motors
## C. Configure the motors
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
### 1. Find the USB ports associated to each arm
**Find USB ports associated to your arms**
To find the correct ports for each arm, run the utility script twice:
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
#### a. Run the script to find ports
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
To find the port for each bus servo adapter, run the utility script:
```bash
python lerobot/scripts/find_motors_bus_port.py
```
#### b. Example outputs
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
```
Finding all available ports for the MotorBus.
@@ -64,7 +77,6 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
@@ -77,13 +89,20 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
#### c. Troubleshooting
On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
**Configure your motors**
#### d. Update YAML file
Now that you have the ports, modify the *port* sections in `so100.yaml`
### 2. Configure the motors
#### a. Set IDs for all 12 motors
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
```bash
python lerobot/scripts/configure_motor.py \
@@ -94,7 +113,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
Then unplug your motor and plug the second motor and set its ID to 2.
```bash
@@ -108,23 +127,25 @@ python lerobot/scripts/configure_motor.py \
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
**Remove the gears of the 6 leader motors**
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
**Add motor horn to the motors**
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
#### b. Remove the gears of the 6 leader motors
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
#### c. Add motor horn to all 12 motors
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
## Assemble the arms
## D. Assemble the arms
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
## Calibrate
## E. Calibrate
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
**Manual calibration of follower arm**
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
#### a. Manual calibration of follower arm
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
You will need to move the follower arm to these positions sequentially:
@@ -139,8 +160,8 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_follower
```
**Manual calibration of leader arm**
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
#### b. Manual calibration of leader arm
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
@@ -153,7 +174,7 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_leader
```
## Teleoperate
## F. Teleoperate
**Simple teleop**
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
@@ -165,14 +186,14 @@ python lerobot/scripts/control_robot.py teleoperate \
```
**Teleop with displaying cameras**
#### a. Teleop with displaying cameras
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/so100.yaml
```
## Record a dataset
## G. Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
@@ -201,7 +222,7 @@ python lerobot/scripts/control_robot.py record \
--push-to-hub 1
```
## Visualize a dataset
## H. Visualize a dataset
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
@@ -214,7 +235,7 @@ python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/so100_test
```
## Replay an episode
## I. Replay an episode
Now try to replay the first episode on your robot:
```bash
@@ -225,7 +246,7 @@ python lerobot/scripts/control_robot.py replay \
--episode 0
```
## Train a policy
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
@@ -248,7 +269,7 @@ Let's explain it:
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
## Evaluate your policy
## K. Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
@@ -268,7 +289,7 @@ As you can see, it's almost the same command as previously used to record your t
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
## More
## L. More Information
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

View File

@@ -0,0 +1,94 @@
# Training a HIL-SERL Reward Classifier with LeRobot
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
---
## Training Script Overview
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).
4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.
---
## Configuring with Hydra
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
### Config File Setup
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
```

## Typical logs and metrics
When you start the training process, you will first see your full configuration printed in the terminal. You can check it to make sure that you configured it correctly and that your config is not overridden by other files. The final configuration will also be saved with the checkpoint.
After that, you will see a training log like this one:
```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```
or evaluation log like:
```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```
### Metrics Tracking with Weights & Biases (WandB)
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
- **Training Metrics**:
- `train/accuracy`
- `train/loss`
- `train/dataloading_s`
- **Evaluation Metrics**:
- `eval/accuracy`
- `eval/loss`
- `eval/eval_s`
#### Additional Features
You can also log sample predictions during evaluation. Each logged sample will include:
- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.
#### Generate protobuf files
```bash
python -m grpc_tools.protoc \
-I lerobot/scripts/server \
--python_out=lerobot/scripts/server \
--grpc_python_out=lerobot/scripts/server \
lerobot/scripts/server/hilserl.proto
```

View File

@@ -18,7 +18,10 @@ import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
@@ -26,7 +29,10 @@ pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
repo_ids = [
info.id
for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])
]
pprint(repo_ids)
# Or simply explore them in your web browser directly at:
@@ -41,7 +47,9 @@ ds_meta = LeRobotDatasetMetadata(repo_id)
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(
f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}"
)
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")

View File

@@ -32,7 +32,9 @@ if torch.cuda.is_available():
print("GPU is available. Device set to:", device)
else:
device = torch.device("cpu")
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
print(
f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU."
)
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
policy.diffusion.num_inference_steps = 10

View File

@@ -31,7 +31,24 @@ delta_timestamps = {
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to supervise the policy.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

View File

@@ -34,10 +34,14 @@ transforms = v2.Compose(
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
transformed_dataset = LeRobotDataset(
dataset_repo_id, episodes=[0], image_transforms=transforms
)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
transformed_frame = transformed_dataset[first_idx][
transformed_dataset.meta.camera_keys[0]
]
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")

View File

@@ -14,7 +14,10 @@ from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda")
@@ -37,7 +40,24 @@ delta_timestamps = {
# Load the previous action (-0.1), the next action to be executed (0.0),
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
# used to calculate the loss.
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
"action": [
-0.1,
0.0,
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
1.2,
1.3,
1.4,
],
}
# Load the last 10% of episodes of the dataset as a validation set.
@@ -53,8 +73,12 @@ print(f"Number of episodes in full dataset: {total_episodes}")
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
# - Load train an val datasets
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
train_dataset = LeRobotDataset(
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")

View File

@@ -69,7 +69,9 @@ def load_raw_dataset(zarr_path: Path):
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
@@ -81,7 +83,9 @@ def calculate_coverage(zarr_data):
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
block_pos = zarr_data["state"][:, 2:4]
@@ -111,7 +115,9 @@ def calculate_coverage(zarr_data):
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area

View File

@@ -182,7 +182,11 @@ available_real_world_datasets = [
]
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
set(
itertools.chain(
*available_datasets_per_env.values(), available_real_world_datasets
)
)
)
# lists all available policies from `lerobot/common/policies`
@@ -224,9 +228,13 @@ available_policies_per_env = {
"dora_aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_task_pairs = [
(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks
]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
(env, dataset)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)

View File

@@ -45,12 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0):
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
assert (
c < h and c < w
), f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
assert (
batch[key].dtype == torch.float32
), f"expect torch.float32, but instead {batch[key].dtype=}"
assert (
batch[key].max() <= 1
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert (
batch[key].min() >= 0
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
@@ -98,7 +106,11 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
tqdm.tqdm(
dataloader,
total=ceil(max_num_samples / batch_size),
desc="Compute mean, min, max",
)
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@@ -113,9 +125,16 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
mean[key] = (
mean[key]
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
)
max[key] = torch.maximum(
max[key], einops.reduce(batch[key], pattern, "max")
)
min[key] = torch.minimum(
min[key], einops.reduce(batch[key], pattern, "min")
)
if i == ceil(max_num_samples / batch_size) - 1:
break
@@ -124,7 +143,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
tqdm.tqdm(
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
)
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
@@ -138,7 +159,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
std[key] = (
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
)
if i == ceil(max_num_samples / batch_size) - 1:
break
@@ -177,13 +200,19 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
[
ds.meta.stats[data_key][stat_key]
for ds in ls_datasets
if data_key in ds.meta.stats
],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
total_samples = sum(
d.num_frames for d in ls_datasets if data_key in d.meta.stats
)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of

View File

@@ -74,7 +74,25 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
image_transforms = None
if cfg.training.image_transforms.enable:
cfg_tf = cfg.training.image_transforms
default_tf = OmegaConf.create(
{
"brightness": {"weight": 0.0, "min_max": None},
"contrast": {"weight": 0.0, "min_max": None},
"saturation": {"weight": 0.0, "min_max": None},
"hue": {"weight": 0.0, "min_max": None},
"sharpness": {"weight": 0.0, "min_max": None},
"max_num_transforms": None,
"random_order": False,
"image_size": None,
"interpolation": None,
"image_mean": None,
"image_std": None,
}
)
cfg_tf = OmegaConf.merge(
OmegaConf.create(default_tf), cfg.training.image_transforms
)
image_transforms = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
@@ -88,6 +106,12 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width)
if cfg_tf.image_size
else None,
interpolation=cfg_tf.interpolation,
image_mean=cfg_tf.image_mean,
image_std=cfg_tf.image_std,
)
if isinstance(cfg.dataset_repo_id, str):
@@ -111,6 +135,8 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
for stats_type, listconfig in stats_dict.items():
# example of stats_type: min, max, mean, std
stats = OmegaConf.to_container(listconfig, resolve=True)
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
return dataset

View File

@@ -109,7 +109,9 @@ class AsyncImageWriter:
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
raise ValueError(
"Number of threads and processes must be greater than zero."
)
if self.num_processes == 0:
# Use threading
@@ -123,12 +125,16 @@ class AsyncImageWriter:
# Use multiprocessing
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p = multiprocessing.Process(
target=worker_process, args=(self.queue, self.num_threads)
)
p.daemon = True
p.start()
self.processes.append(p)
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()

View File

@@ -68,7 +68,9 @@ from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
LEROBOT_HOME = Path(
os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")
).expanduser()
class LeRobotDatasetMetadata:
@@ -84,7 +86,8 @@ class LeRobotDatasetMetadata:
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
if not self.local_files_only:
self.pull_from_repo(allow_patterns="meta/")
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
@@ -107,7 +110,11 @@ class LeRobotDatasetMetadata:
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
return (
None
if self.local_files_only
else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
)
@property
def _version(self) -> str:
@@ -121,7 +128,9 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.video_path.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index
)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
@@ -165,7 +174,11 @@ class LeRobotDatasetMetadata:
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
return [
key
for key, ft in self.features.items()
if ft["dtype"] in ["video", "image"]
]
@property
def names(self) -> dict[str, list | dict]:
@@ -214,7 +227,9 @@ class LeRobotDatasetMetadata:
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
def save_episode(
self, episode_index: int, episode_length: int, task: str, task_index: int
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
@@ -256,7 +271,9 @@ class LeRobotDatasetMetadata:
"""
for key in self.video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
video_path = self.root / self.get_video_file_path(
ep_index=0, vid_key=key
)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
@@ -291,7 +308,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
features = {**(features or {}), **get_features_from_robot(robot)}
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
@@ -307,7 +324,9 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -443,7 +462,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.local_files_only
)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
@@ -451,10 +472,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
check_timestamps_sync(
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -500,7 +525,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
create_branch(
repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset"
)
def pull_from_repo(
self,
@@ -528,7 +555,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [
str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
@@ -537,7 +566,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
]
files += video_files
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
if not self.local_files_only:
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
@@ -545,7 +575,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
path = str(self.root / "data")
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
files = [
str(self.root / self.meta.get_data_file_path(ep_idx))
for ep_idx in self.episodes
]
hf_dataset = load_dataset("parquet", data_files=files, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
@@ -561,12 +594,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of frames in selected episodes."""
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
return (
len(self.hf_dataset)
if self.hf_dataset is not None
else self.meta.total_frames
)
@property
def num_episodes(self) -> int:
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
return (
len(self.episodes)
if self.episodes is not None
else self.meta.total_episodes
)
@property
def features(self) -> dict[str, dict]:
@@ -580,16 +621,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
def _get_query_indices(
self, idx: int, ep_idx: int
) -> tuple[dict[str, list[int | bool]]]:
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
query_indices = {
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
key: [
max(ep_start.item(), min(ep_end.item() - 1, idx + delta))
for delta in delta_idx
]
for key, delta_idx in self.delta_indices.items()
}
padding = { # Pad values outside of current episode range
f"{key}_is_pad": torch.BoolTensor(
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
[
(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item())
for delta in delta_idx
]
)
for key, delta_idx in self.delta_indices.items()
}
@@ -617,7 +666,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -647,7 +698,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
current_ep_idx = (
self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
)
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
@@ -679,19 +732,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
return {
"size": 0,
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
**{
key: current_ep_idx if key == "episode_index" else []
for key in self.features
},
}
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
def _get_image_file_path(
self, episode_index: int, image_key: str, frame_index: int
) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self.root / fpath
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
@@ -712,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
timestamp = (
frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
)
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
@@ -721,11 +785,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
raise ValueError(key)
if self.features[key]["dtype"] not in ["image", "video"]:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
item = (
frame[key].numpy()
if isinstance(frame[key], torch.Tensor)
else frame[key]
)
self.episode_buffer[key].append(item)
elif self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
episode_index=self.episode_buffer["episode_index"],
image_key=key,
frame_index=frame_index,
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
@@ -734,7 +804,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["size"] += 1
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(
self, task: str, encode_videos: bool = True, episode_data: dict | None = None
) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
@@ -801,7 +873,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
@@ -873,10 +947,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
def consolidate(
self, run_compute_stats: bool = True, keep_image_files: bool = False
) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
check_timestamps_sync(
self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s
)
if len(self.meta.video_keys) > 0:
self.encode_videos()
@@ -981,7 +1061,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
@@ -1058,7 +1140,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
features.update(
{
k: v
for k, v in dataset.hf_features.items()
if k not in self.disabled_features
}
)
return features
@property
@@ -1119,7 +1207,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
raise AssertionError(
"We expect the loop to break out as long as the index is within bounds."
)
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_features:

View File

@@ -131,7 +131,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
else:
self._delta_timestamps = None
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
def _make_data_spec(
self, data_spec: dict[str, Any], buffer_capacity: int
) -> dict[str, dict[str, Any]]:
"""Makes the data spec for np.memmap."""
if any(k.startswith("_") for k in data_spec):
raise ValueError(
@@ -154,14 +156,32 @@ class OnlineBuffer(torch.utils.data.Dataset):
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
# with real data rather than the dummy initialization.
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
OnlineBuffer.OCCUPANCY_MASK_KEY: {
"dtype": np.dtype("?"),
"shape": (buffer_capacity,),
},
OnlineBuffer.INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.FRAME_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.EPISODE_INDEX_KEY: {
"dtype": np.dtype("int64"),
"shape": (buffer_capacity,),
},
OnlineBuffer.TIMESTAMP_KEY: {
"dtype": np.dtype("float64"),
"shape": (buffer_capacity,),
},
}
for k, v in data_spec.items():
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
complete_data_spec[k] = {
"dtype": v["dtype"],
"shape": (buffer_capacity, *v["shape"]),
}
return complete_data_spec
def add_data(self, data: dict[str, np.ndarray]):
@@ -188,7 +208,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Shift the incoming indices if necessary.
if self.num_frames > 0:
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
next_index - 1
]
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -223,7 +245,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
np.unique(
self._data[OnlineBuffer.EPISODE_INDEX_KEY][
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
]
)
)
@property
@@ -261,7 +287,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
)
)[0]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
episode_data_indices
]
for data_key in self.delta_timestamps:
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -278,7 +306,8 @@ class OnlineBuffer(torch.utils.data.Dataset):
# Check violated query timestamps are all outside the episode range.
assert (
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
(query_ts[is_pad] < episode_timestamps[0])
| (episode_timestamps[-1] < query_ts[is_pad])
).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
") inside the episode range."
@@ -293,7 +322,9 @@ class OnlineBuffer(torch.utils.data.Dataset):
def get_data_by_key(self, key: str) -> torch.Tensor:
"""Returns all data for a given data key as a Tensor."""
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
return torch.from_numpy(
self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
)
def compute_sampler_weights(
@@ -324,13 +355,19 @@ def compute_sampler_weights(
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
included here to avoid adding complexity.
"""
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
if len(offline_dataset) == 0 and (
online_dataset is None or len(online_dataset) == 0
):
raise ValueError(
"At least one of `offline_dataset` or `online_dataset` should be contain data."
)
if (online_dataset is None) ^ (online_sampling_ratio is None):
raise ValueError(
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
)
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
offline_sampling_ratio = (
0 if online_sampling_ratio is None else 1 - online_sampling_ratio
)
weights = []

View File

@@ -37,10 +37,16 @@ def check_chunks_compatible(chunks: tuple, shape: tuple):
assert c > 0
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
def rechunk_recompress_array(
group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"
):
old_arr = group[name]
if chunks is None:
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
chunks = (
(chunk_length,) + old_arr.chunks[1:]
if chunk_length is not None
else old_arr.chunks
)
check_chunks_compatible(chunks, old_arr.shape)
if compressor is None:
@@ -82,13 +88,18 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No
for i in range(len(shape) - 1):
this_chunk_bytes = itemsize * np.prod(rshape[:i])
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
if (
this_chunk_bytes <= target_chunk_bytes
and next_chunk_bytes > target_chunk_bytes
):
split_idx = i
rchunks = rshape[:split_idx]
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
this_max_chunk_length = rshape[split_idx]
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
next_chunk_length = min(
this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)
)
rchunks.append(next_chunk_length)
len_diff = len(shape) - len(rchunks)
rchunks.extend([1] * len_diff)
@@ -124,7 +135,13 @@ class ReplayBuffer:
root.require_group("data", overwrite=False)
meta = root.require_group("meta", overwrite=False)
if "episode_ends" not in meta:
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
meta.zeros(
"episode_ends",
shape=(0,),
dtype=np.int64,
compressor=None,
overwrite=False,
)
return cls(root=root)
@classmethod
@@ -193,7 +210,11 @@ class ReplayBuffer:
root = zarr.group(store=store)
# copy without recompression
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
source=src_store,
dest=store,
source_path="/meta",
dest_path="/meta",
if_exists=if_exists,
)
data_group = root.create_group("data", overwrite=True)
if keys is None:
@@ -201,7 +222,9 @@ class ReplayBuffer:
for key in keys:
value = src_root["data"][key]
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
cpr = cls._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
this_path = "/data/" + key
@@ -286,13 +309,17 @@ class ReplayBuffer:
meta_group = root.create_group("meta", overwrite=True)
# save meta, no chunking
for key, value in self.root["meta"].items():
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape
)
# save data, chunk
data_group = root.create_group("data", overwrite=True)
for key, value in self.root["data"].items():
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
if isinstance(value, zarr.Array):
if cks == value.chunks and cpr == value.compressor:
# copy without recompression
@@ -339,13 +366,19 @@ class ReplayBuffer:
@staticmethod
def resolve_compressor(compressor="default"):
if compressor == "default":
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
compressor = numcodecs.Blosc(
cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE
)
elif compressor == "disk":
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
compressor = numcodecs.Blosc(
"zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE
)
return compressor
@classmethod
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
def _resolve_array_compressor(
cls, compressors: dict | str | numcodecs.abc.Codec, key, array
):
# allows compressor to be explicitly set to None
cpr = "nil"
if isinstance(compressors, dict):
@@ -404,7 +437,11 @@ class ReplayBuffer:
if self.backend == "zarr":
for key, value in np_data.items():
_ = meta_group.array(
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
name=key,
data=value,
shape=value.shape,
chunks=value.shape,
overwrite=True,
)
else:
meta_group.update(np_data)
@@ -514,10 +551,18 @@ class ReplayBuffer:
# create array
if key not in self.data:
if is_zarr:
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
cks = self._resolve_array_chunks(
chunks=chunks, key=key, array=value
)
cpr = self._resolve_array_compressor(
compressors=compressors, key=key, array=value
)
arr = self.data.zeros(
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
name=key,
shape=new_shape,
chunks=cks,
dtype=value.dtype,
compressor=cpr,
)
else:
# copy data to prevent modify
@@ -544,7 +589,9 @@ class ReplayBuffer:
# rechunk
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
rechunk_recompress_array(
self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)
)
def drop_episode(self):
is_zarr = self.backend == "zarr"

View File

@@ -38,7 +38,9 @@ import argparse
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
from lerobot.common.datasets.push_dataset_to_hub._download_raw import (
AVAILABLE_RAW_REPO_IDS,
)
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
@@ -73,7 +75,9 @@ def encode_datasets(
check_repo_id(raw_repo_id)
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
dataset_raw_dir = raw_dir / raw_repo_id
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
dataset_dir = (
local_dir / dataset_repo_id_push if local_dir is not None else None
)
encoding = {
"vcodec": vcodec,
"pix_fmt": pix_fmt,

View File

@@ -133,7 +133,9 @@ class Jpeg2k(Codec):
)
def decode(self, buf, out=None):
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
return imagecodecs.jpeg2k_decode(
buf, verbose=self.verbose, numthreads=self.numthreads, out=out
)
class JpegXl(Codec):

View File

@@ -44,7 +44,9 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def get_cameras(hdf5_data):
# ignore depth channel, not currently handled
# TODO(rcadene): add depth
rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
rgb_cameras = [
key for key in hdf5_data["/observations/images"].keys() if "depth" not in key
] # noqa: SIM118
return rgb_cameras
@@ -73,7 +75,9 @@ def check_format(raw_dir) -> bool:
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
assert (
c < h and c < w
), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(
@@ -134,14 +138,17 @@ def load_from_raw(
# encode images to a mp4 video
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
video_path = videos_dir / fname
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -181,15 +188,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -26,7 +26,9 @@ import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
@@ -42,11 +44,19 @@ def check_format(raw_dir) -> bool:
return True
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
def load_from_raw(
raw_dir: Path,
videos_dir: Path,
fps: int,
video: bool,
episodes: list[int] | None = None,
):
# Load data stream that will be used as reference for the timestamps synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
raise ValueError(
f"Missing reference files for camera, starting with in '{raw_dir}'"
)
# select first camera in alphanumeric order
reference_key = sorted(reference_files)[0].stem
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
@@ -107,7 +117,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
# each episode starts with timestamp 0 to match the ones from the video
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
lambda x: x - x.iloc[0]
)
del df["timestamp_utc"]
@@ -120,7 +132,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
raise ValueError(
f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}"
)
# Create symlink to raw videos directory (that needs to be absolute not relative)
videos_dir.parent.mkdir(parents=True, exist_ok=True)
@@ -152,7 +166,9 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
data_dict[key] = torch.from_numpy(df[key].values)
# is vector
elif df[key].iloc[0].shape[0] > 1:
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
data_dict[key] = torch.stack(
[torch.from_numpy(x.copy()) for x in df[key].values]
)
else:
raise ValueError(key)
@@ -170,15 +186,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.velocity"].shape[1],
feature=Value(dtype="float32", id=None),
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.effort"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -143,7 +143,11 @@ def load_from_raw(
else:
state_keys.append(key)
lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
lang_key = (
"language_instruction"
if "language_instruction" in dataset.element_spec
else None
)
print(" - image_keys: ", image_keys)
print(" - lang_key: ", lang_key)
@@ -202,7 +206,9 @@ def load_from_raw(
# If lang_key is present, convert the entire tensor at once
if lang_key is not None:
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
ep_dict["language_instruction"] = [
x.numpy().decode("utf-8") for x in episode[lang_key]
]
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
@@ -234,7 +240,8 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -259,7 +266,9 @@ def to_hf_dataset(data_dict, video) -> Dataset:
for key in data_dict:
# check if vector state obs
if key.startswith("observation.") and "observation.images." not in key:
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
features[key] = Sequence(
length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)
)
# check if image obs
elif "observation.images." in key:
if video:

View File

@@ -56,7 +56,9 @@ def check_format(raw_dir):
required_datasets.remove("meta/episode_ends")
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
def load_from_raw(
@@ -76,7 +78,9 @@ def load_from_raw(
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
print(
"`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`"
)
raise e
# as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174
success_threshold = 0.95 # 95% coverage,
@@ -150,7 +154,9 @@ def load_from_raw(
]
space.add(*walls)
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
block_body, block_shapes = PushTEnv.add_tee(
space, block_pos[i].tolist(), block_angle[i].item()
)
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
@@ -159,7 +165,9 @@ def load_from_raw(
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
if keypoints_instead_of_image:
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
keypoints[i] = torch.from_numpy(
PushTEnv.get_keypoints(block_shapes).flatten()
)
# last step of demonstration is considered done
done[-1] = True
@@ -184,7 +192,8 @@ def load_from_raw(
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -193,7 +202,9 @@ def load_from_raw(
if keypoints_instead_of_image:
ep_dict["observation.environment_state"] = keypoints
ep_dict["action"] = actions[from_idx:to_idx]
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = image[1:],
@@ -220,7 +231,8 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
if keypoints_instead_of_image:
features["observation.environment_state"] = Sequence(
@@ -261,7 +273,9 @@ def from_raw_to_lerobot_format(
if fps is None:
fps = 10
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
data_dict = load_from_raw(
raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding
)
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
episode_data_index = calculate_episode_data_index(hf_dataset)
info = {

View File

@@ -26,7 +26,9 @@ from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import (
register_codecs,
)
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
@@ -61,7 +63,9 @@ def check_format(raw_dir) -> bool:
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
required_datasets.remove("meta/episode_ends")
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
assert all(
nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets
)
def load_from_raw(
@@ -79,7 +83,9 @@ def load_from_raw(
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
eff_rot_axis_angle = torch.from_numpy(
zarr_data["data/robot0_eef_rot_axis_angle"][:]
)
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
@@ -129,24 +135,31 @@ def load_from_raw(
save_images_concurrently(imgs_array, tmp_imgs_dir)
# encode images to a mp4 video
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
encode_video_frames(
tmp_imgs_dir, video_path, fps, **(encoding or {})
)
# clean temporary images directory
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
ep_dict["episode_data_index_to"] = torch.tensor(
[from_idx + num_frames] * num_frames
)
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
@@ -172,7 +185,8 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
@@ -192,7 +206,8 @@ def to_hf_dataset(data_dict, video):
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
)
features["gripper_width"] = Sequence(
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["gripper_width"].shape[1],
feature=Value(dtype="float32", id=None),
)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))

View File

@@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts):
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
def save_images_concurrently(
imgs_array: numpy.array, out_dir: Path, max_workers: int = 4
):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
@@ -55,7 +57,10 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
[
executor.submit(save_image, imgs_array[i], i, out_dir)
for i in range(num_images)
]
def get_default_encoding() -> dict:
@@ -64,7 +69,8 @@ def get_default_encoding() -> dict:
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
if v.default is not inspect.Parameter.empty
and k in ["vcodec", "pix_fmt", "g", "crf"]
}
@@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None:
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> Dict[str, torch.Tensor]:
"""
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.

View File

@@ -40,7 +40,10 @@ from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
def check_format(raw_dir):
keys = {"actions", "rewards", "dones"}
nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
nested_keys = {
"observations": {"rgb", "state"},
"next_observations": {"rgb", "state"},
}
xarm_files = list(raw_dir.glob("*.pkl"))
assert len(xarm_files) > 0
@@ -53,11 +56,17 @@ def check_format(raw_dir):
# Check for consistent lengths in nested keys
expected_len = len(dataset_dict["actions"])
assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict)
assert all(
len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict
)
for key, subkeys in nested_keys.items():
nested_dict = dataset_dict.get(key, {})
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
assert all(
len(nested_dict[subkey]) == expected_len
for subkey in subkeys
if subkey in nested_dict
)
def load_from_raw(
@@ -122,13 +131,18 @@ def load_from_raw(
shutil.rmtree(tmp_imgs_dir)
# store the reference to the video frame
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
ep_dict[img_key] = [
{"path": f"videos/{fname}", "timestamp": i / fps}
for i in range(num_frames)
]
else:
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["episode_index"] = torch.tensor(
[ep_idx] * num_frames, dtype=torch.int64
)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
# ep_dict["next.observation.image"] = next_image
@@ -153,7 +167,8 @@ def to_hf_dataset(data_dict, video):
features["observation.image"] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
length=data_dict["observation.state"].shape[1],
feature=Value(dtype="float32", id=None),
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)

View File

@@ -43,7 +43,10 @@ class EpisodeAwareSampler:
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
range(
start_index.item() + drop_n_first_frames,
end_index.item() - drop_n_last_frames,
)
)
self.indices = indices
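A self-contained sketch of the trimming logic above, under the assumption that `episode_data_index` holds "from"/"to" frame tensors:

```python
import torch

episode_data_index = {"from": torch.tensor([0, 10]), "to": torch.tensor([10, 25])}
drop_n_first_frames, drop_n_last_frames = 2, 3

indices = []
for start_index, end_index in zip(
    episode_data_index["from"], episode_data_index["to"], strict=True
):
    # Keep only the interior of each episode, dropping unstable boundary frames.
    indices.extend(
        range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
    )

print(indices)  # [2, 3, 4, 5, 6] then [12, ..., 21]
```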

View File

@@ -57,7 +57,9 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
self.transforms = transforms
total = sum(p)
@@ -116,16 +118,22 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
return float(sharpness[0]), float(sharpness[1])
@@ -134,7 +142,9 @@ class SharpnessJitter(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
def get_image_transforms(
@@ -150,6 +160,10 @@ def get_image_transforms(
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
interpolation: str | None = None,
image_size: tuple[int, int] | None = None,
image_mean: list[float] | None = None,
image_std: list[float] | None = None,
):
def check_value(name, weight, min_max):
if min_max is not None:
@@ -170,6 +184,22 @@ def get_image_transforms(
weights = []
transforms = []
if image_size is not None:
interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
if interpolation is None:
# Use BICUBIC as default interpolation
interpolation_mode = v2.InterpolationMode.BICUBIC
elif interpolation in interpolations:
interpolation_mode = v2.InterpolationMode(interpolation)
else:
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(
v2.Resize(
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
)
)
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +215,15 @@ def get_image_transforms(
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
if image_mean is not None and image_std is not None:
# Weight for normalization is always 1
weights.append(1.0)
transforms.append(
v2.Normalize(
mean=image_mean,
std=image_std,
)
)
n_subset = len(transforms)
if max_num_transforms is not None:
@@ -194,4 +233,6 @@ def get_image_transforms(
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
return RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)
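A hedged usage sketch of the extended factory; the argument values are illustrative, and the ImageNet mean/std are shown only as a common choice:

```python
transforms = get_image_transforms(
    brightness_weight=1.0,
    brightness_min_max=(0.8, 1.2),
    sharpness_weight=1.0,
    sharpness_min_max=(0.5, 1.5),
    max_num_transforms=3,
    random_order=True,
    # New in this change: resize and normalization steps (each added with weight 1.0).
    interpolation="bicubic",
    image_size=(224, 224),
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
# Note the design choice visible in the code above: resize and normalize enter the same
# RandomSubsetApply pool as the jitters, so with max_num_transforms set they can be
# subsampled on a given call rather than always applied.
```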

View File

@@ -17,9 +17,11 @@ import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
import datasets
@@ -41,9 +43,15 @@ EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_VIDEO_PATH = (
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
)
DEFAULT_IMAGE_PATH = (
"images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
)
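The added parentheses are purely cosmetic line wrapping; the templates still resolve with plain `str.format`:

```python
DEFAULT_VIDEO_PATH = (
    "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
)

path = DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key="observation.image", episode_index=42)
print(path)  # videos/chunk-000/observation.image/episode_000042.mp4
```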
DATASET_CARD_TEMPLATE = """
---
@@ -97,7 +105,9 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
serialized_dict = {
key: value.tolist() for key, value in flatten_dict(stats).items()
}
return unflatten_dict(serialized_dict)
@@ -155,14 +165,19 @@ def load_stats(local_dir: Path) -> dict:
def load_tasks(local_dir: Path) -> dict:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
return {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
def load_image_as_numpy(
fpath: str | Path, dtype="float32", channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
@@ -220,7 +235,10 @@ class BackwardCompatibilityError(Exception):
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
repo_id: str,
version_to_check: str,
current_version: str,
enforce_breaking_major: bool = True,
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
@@ -273,6 +291,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
# TODO: (aliberts, azouitine) Add support for ft["shape"] == 0 as Value
return datasets.Features(hf_features)
@@ -314,7 +333,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -335,7 +356,9 @@ def calculate_total_episode(
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
def calculate_episode_data_index(
hf_dataset: datasets.Dataset,
) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
@@ -377,7 +400,9 @@ def check_timestamps_sync(
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_filtered_indices = torch.nonzero(
~filtered_within_tolerance
) # .squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
@@ -402,7 +427,10 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
delta_timestamps: dict[str, list[float]],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
@@ -410,10 +438,14 @@ def check_delta_timestamps(
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
within_tolerance = [
abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts
]
if not all(within_tolerance):
outside_tolerance[key] = [
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
ts
for ts, is_within in zip(delta_ts, within_tolerance, strict=True)
if not is_within
]
if len(outside_tolerance) > 0:
@@ -431,7 +463,9 @@ def check_delta_timestamps(
return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
def get_delta_indices(
delta_timestamps: dict[str, list[float]], fps: int
) -> dict[str, list[int]]:
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
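A quick worked example covering both helpers above at 30 fps: deltas of ±0.1 s land exactly on frame boundaries and map to integer offsets of ±3 frames.

```python
import torch

fps, tolerance_s = 30, 1e-4
delta_timestamps = {"action": [-0.1, 0.0, 0.1]}

# check_delta_timestamps: each delta must sit on a frame boundary within tolerance.
for ts in delta_timestamps["action"]:
    err_s = abs(ts * fps - round(ts * fps)) / fps  # distance to the nearest frame, in seconds
    assert err_s <= tolerance_s  # -0.1 s and 0.1 s are exactly -3 and +3 frames at 30 fps

# get_delta_indices: the same deltas expressed as integer frame offsets.
delta_indices = {
    k: (torch.tensor(v) * fps).long().tolist() for k, v in delta_timestamps.items()
}
print(delta_indices)  # {'action': [-3, 0, 3]}
```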
@@ -477,7 +511,6 @@ def create_lerobot_dataset_card(
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
if tags:
card_tags += tags
@@ -497,8 +530,67 @@ def create_lerobot_dataset_card(
],
)
card_template = (
importlib.resources.files("lerobot.common.datasets") / "card_template.md"
).read_text()
return DatasetCard.from_template(
card_data=card_data,
template_path=str(card_template_path),
template_str=card_template,
**kwargs,
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()
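To complement the docstring examples, note that `vars()` only unwraps one level, so nested values remain `IterableNamespace` objects (a small hedged round-trip):

```python
ns = IterableNamespace({"name": "Alice", "details": {"age": 25}})
shallow = vars(ns)             # {'name': 'Alice', 'details': IterableNamespace(age=25)}
print(shallow["details"].age)  # 25
print(ns["details"]["age"])    # bracket access works at every level too
```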

View File

@@ -26,7 +26,10 @@ from pathlib import Path
from textwrap import dedent
from lerobot import available_datasets
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import (
convert_dataset,
parse_robot_config,
)
LOCAL_DIR = Path("data/")
@@ -117,7 +120,10 @@ DATASETS = {
"single_task": "Place the battery into the slot of the remote controller.",
**ALOHA_STATIC_INFO,
},
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
"aloha_static_candy": {
"single_task": "Pick up the candy and unwrap it.",
**ALOHA_STATIC_INFO,
},
"aloha_static_coffee": {
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
**ALOHA_STATIC_INFO,
@@ -159,20 +165,29 @@ DATASETS = {
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {
"single_task": "Slide open the ziploc bag.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_scripted": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_scripted_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
"aloha_sim_insertion_human": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
},
"aloha_sim_insertion_human_image": {
"single_task": "Insert the peg into the socket.",
**ALOHA_STATIC_INFO,
@@ -193,10 +208,19 @@ DATASETS = {
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
**ALOHA_STATIC_INFO,
},
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
"pusht": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"pusht_image": {
"single_task": "Push the T-shaped block onto the T-shaped target.",
**PUSHT_INFO,
},
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
"unitreeh1_rearrange_objects": {
"single_task": "Put the object into the bin.",
**UNITREEH_INFO,
},
"unitreeh1_two_robot_greeting": {
"single_task": "Greet the other robot with a high five.",
**UNITREEH_INFO,
@@ -206,13 +230,31 @@ DATASETS = {
**UNITREEH_INFO,
},
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
"xarm_lift_medium_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_lift_medium_replay_image": {
"single_task": "Pick up the cube and lift it.",
**XARM_INFO,
},
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
"xarm_push_medium_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"xarm_push_medium_replay_image": {
"single_task": "Push the cube onto the target.",
**XARM_INFO,
},
"umi_cup_in_the_wild": {
"single_task": "Put the cup on the plate.",
"license": "apache-2.0",

View File

@@ -152,7 +152,9 @@ V1_INFO_PATH = "meta_data/info.json"
V1_STATS_PATH = "meta_data/stats.safetensors"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
def parse_robot_config(
config_path: Path, config_overrides: list[str] | None = None
) -> tuple[str, dict]:
robot_cfg = init_hydra_config(config_path, config_overrides)
if robot_cfg["robot_type"] in ["aloha", "koch"]:
state_names = [
@@ -203,7 +205,9 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: dict | None = None
) -> dict[str, list]:
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
@@ -215,7 +219,9 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
dtype = ft.feature.dtype
shape = (ft.length,)
motor_names = (
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
robot_config["names"][key]
if robot_config
else [f"motor_{i}" for i in range(ft.length)]
)
assert len(motor_names) == shape[0]
names = {"motors": motor_names}
@@ -239,11 +245,15 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N
return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
def add_task_index_by_episodes(
dataset: Dataset, tasks_by_episodes: dict
) -> tuple[Dataset, list[str]]:
df = dataset.to_pandas()
tasks = list(set(tasks_by_episodes.values()))
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
episodes_to_task_index = {
ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()
}
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
features = dataset.features
@@ -260,10 +270,19 @@ def add_task_index_from_tasks_col(
# HACK: This is to clean some of the instructions in our version of Open X datasets
prefix_to_clean = "tf.Tensor(b'"
suffix_to_clean = "', shape=(), dtype=string)"
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
df[tasks_col] = (
df[tasks_col]
.str.removeprefix(prefix_to_clean)
.str.removesuffix(suffix_to_clean)
)
# Create task_index col
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
tasks_by_episode = (
df.groupby("episode_index")[tasks_col]
.unique()
.apply(lambda x: x.tolist())
.to_dict()
)
tasks = df[tasks_col].unique().tolist()
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
@@ -288,7 +307,9 @@ def split_parquet_by_episodes(
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk
)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
@@ -320,7 +341,9 @@ def move_videos(
videos_moved = False
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
if len(video_files) == 0:
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
video_files = [
str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")
]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
@@ -351,7 +374,9 @@ def move_videos(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
video_file = V1_VIDEO_FILE.format(
video_key=vid_key, episode_index=ep_idx
)
if len(video_dirs) == 1:
video_path = video_dirs[0] / video_file
else:
@@ -368,7 +393,9 @@ def move_videos(
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
def fix_lfs_video_files_tracking(
work_dir: Path, lfs_untracked_videos: list[str]
) -> None:
"""
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
there's no other option than to download the actual files and reupload them with lfs tracking.
@@ -376,7 +403,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
for i in range(0, len(lfs_untracked_videos), 100):
files = lfs_untracked_videos[i : i + 100]
try:
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
subprocess.run(
["git", "rm", "--cached", *files],
cwd=work_dir,
capture_output=True,
check=True,
)
except subprocess.CalledProcessError as e:
print("git rm --cached ERROR:")
print(e.stderr)
@@ -387,10 +419,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
subprocess.run(["git", "push"], cwd=work_dir, check=True)
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
def fix_gitattributes(
work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path
) -> None:
shutil.copyfile(clean_gittatributes, current_gittatributes)
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
subprocess.run(
["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True
)
subprocess.run(["git", "push"], cwd=work_dir, check=True)
@@ -399,7 +435,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
repo_url = f"https://huggingface.co/datasets/{repo_id}"
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
subprocess.run(
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
[
"git",
"clone",
"--branch",
branch,
"--single-branch",
"--depth",
"1",
repo_url,
str(work_dir),
],
check=True,
env=env,
)
@@ -407,13 +453,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
lfs_tracked_files = subprocess.run(
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
["git", "lfs", "ls-files", "-n"],
cwd=work_dir,
capture_output=True,
text=True,
check=True,
)
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
return [f for f in video_files if f not in lfs_tracked_files]
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
def get_videos_info(
repo_id: str, local_dir: Path, video_keys: list[str], branch: str
) -> dict:
# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
@@ -421,7 +473,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
repo_id=repo_id,
repo_type="dataset",
local_dir=local_dir,
revision=branch,
allow_patterns=video_files,
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
@@ -448,7 +504,11 @@ def convert_dataset(
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
repo_id=repo_id,
repo_type="dataset",
revision=v1,
local_dir=v1x_dir,
ignore_patterns="videos*/",
)
branch = "main"
if test_branch:
@@ -480,19 +540,31 @@ def convert_dataset(
if single_task:
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()
}
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
tasks_by_episodes = {
ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()
}
elif tasks_col:
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(
dataset, tasks_col
)
else:
raise ValueError
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
assert set(tasks) == {
task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks
}
tasks = [
{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)
]
write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
"dtype": "int64",
@@ -506,14 +578,25 @@ def convert_dataset(
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
repo_id=GITATTRIBUTES_REF,
repo_type="dataset",
local_dir=local_dir,
filename=".gitattributes",
)
).absolute()
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
repo_id,
video_keys,
total_episodes,
total_chunks,
Path(tmp_video_dir),
clean_gitattr,
branch,
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
videos_info = get_videos_info(
repo_id, v1x_dir, video_keys=video_keys, branch=branch
)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.height"),
@@ -521,15 +604,22 @@ def convert_dataset(
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
assert math.isclose(
videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3
)
if "encoding" in metadata_v1:
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
assert (
videos_info[key]["video.pix_fmt"]
== metadata_v1["encoding"]["pix_fmt"]
)
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
episode_lengths = split_parquet_by_episodes(
dataset, total_episodes, total_chunks, v20_dir
)
if robot_config is not None:
robot_type = robot_config["robot_type"]
@@ -540,7 +630,11 @@ def convert_dataset(
# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
{
"episode_index": ep_idx,
"tasks": tasks_by_episodes[ep_idx],
"length": episode_lengths[ep_idx],
}
for ep_idx in episode_indices
]
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
@@ -563,16 +657,27 @@ def convert_dataset(
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
card = create_lerobot_dataset_card(
tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id,
path_in_repo="meta_data",
repo_type="dataset",
revision=branch,
)
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
hub_api.delete_folder(
repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch
)
hub_api.upload_folder(
repo_id=repo_id,
@@ -655,7 +760,11 @@ def main():
if not args.local_dir:
args.local_dir = Path("/tmp/lerobot_dataset_v2")
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
robot_config = (
parse_robot_config(args.robot_config, args.robot_overrides)
if args.robot_config
else None
)
del args.robot_config, args.robot_overrides
convert_dataset(**vars(args), robot_config=robot_config)

View File

@@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
@@ -241,7 +243,9 @@ def get_audio_info(video_path: Path | str) -> dict:
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.bit_rate": int(audio_stream_info["bit_rate"])
if audio_stream_info.get("bit_rate")
else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
@@ -263,7 +267,9 @@ def get_video_info(video_path: Path | str) -> dict:
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
result = subprocess.run(
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")

View File

@@ -14,9 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from collections import deque
import gymnasium as gym
import numpy as np
import torch
from omegaconf import DictConfig
# from mani_skill.utils import common
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
@@ -30,6 +34,12 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
if cfg.env.name == "real_world":
return
if "maniskill" in cfg.env.name:
env = make_maniskill_env(
cfg, n_envs if n_envs is not None else cfg.eval.batch_size
)
return env
package_name = f"gym_{cfg.env.name}"
try:
@@ -47,7 +57,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env_cls = (
gym.vector.AsyncVectorEnv
if cfg.eval.use_async_envs
else gym.vector.SyncVectorEnv
)
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
@@ -56,3 +70,99 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
)
return env
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
num_envs=n_envs,
)
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
env = ManiSkillVectorEnv(env, ignore_terminations=True)
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
return env
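A hedged sketch of the config this factory expects, with field names read off the code above and values that are purely illustrative:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {
            "name": "maniskill_pushcube",  # "maniskill" in the name routes make_env() here
            "task": "PushCube-v1",         # illustrative task id
            "obs": "rgbd",
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,
        },
        "eval": {"batch_size": 2, "use_async_envs": False},
    }
)
env = make_env(cfg)  # a ManiSkillVectorEnv with cfg.eval.batch_size parallel envs
```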
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Works with Maniskill vectorized environments
"""
def __init__(self, cfg, env, num_envs, num_frames=3):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
dtype=np.uint8,
)
self._frames = deque([], maxlen=num_frames)
self._render_size = cfg.env.render_size
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
for _ in range(self._frames.maxlen):
obs_frames = self._get_obs(obs)
return obs_frames, info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
# TODO: Remove this
class ConvertToLeRobotEnv(gym.Wrapper):
def __init__(self, env, num_envs):
super().__init__(env)
def reset(self, seed=None, options=None):
obs, info = self.env.reset(seed=seed, options={})
return self._get_obs(obs), info
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._get_obs(obs), reward, terminated, truncated, info
def _get_obs(self, observation):
sensor_data = observation.pop("sensor_data")
del observation["sensor_param"]
images = []
for cam_data in sensor_data.values():
images.append(cam_data["rgb"])
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
ret = dict()
ret["state"] = observation
ret["pixels"] = images
return ret

View File

@@ -28,28 +28,32 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
"""
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
# TODO: Merge all tensors from the agent and extra keys; drop the sensor_param key
# from the observation and keep the sensor_data rgb.
for key, img in observations.items():
if "images" not in key:
continue
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert (
c < h and c < w
), f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
return_observations[key] = img
# The observation state is the agent's qpos and qvel; images come from the sensor data.
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
@@ -58,5 +62,43 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
return_observations["observation.state"] = observations["observation.state"].float()
return return_observations
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observations: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: Merge all tensors from the agent and extra keys; drop the sensor_param key
# from the observation and keep the sensor_data rgb.
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.state"] = state
return return_observations
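A hedged shape check for this converter (batch of 2, Panda-like dimensions assumed; 9 + 9 + 7 = 25 matches the "state should have the size of 25" note earlier):

```python
import torch

observations = {
    "agent": {"qpos": torch.zeros(2, 9), "qvel": torch.zeros(2, 9)},
    "extra": {"tcp_pose": torch.zeros(2, 7)},  # position (3) + quaternion (4)
    "sensor_data": {"base_camera": {"rgb": torch.zeros(2, 64, 64, 3, dtype=torch.uint8)}},
}
out = preprocess_maniskill_observation(observations)
print(out["observation.image"].shape)  # torch.Size([2, 3, 64, 64]), float32 in [0, 1]
print(out["observation.state"].shape)  # torch.Size([2, 25]) = qpos (9) + qvel (9) + tcp_pose (7)
```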

View File

@@ -25,6 +25,7 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -83,7 +84,9 @@ class Logger:
pretrained_model_dir_name = "pretrained_model"
training_state_file_name = "training_state.pth"
def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
def __init__(
self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
):
"""
Args:
log_dir: The directory to save all logs and training outputs to.
@@ -103,12 +106,12 @@ class Logger:
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(
colored("Logs will be saved locally.", "yellow", attrs=["bold"])
)
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
@@ -128,8 +131,12 @@ class Logger:
job_type="train_eval",
resume="must" if cfg.resume else None,
)
# Handle custom step key for asynchronous RL training.
self._wandb_custom_step_key: set[str] | None = None
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
@classmethod
@@ -150,7 +157,9 @@ class Logger:
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
def save_model(
self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
@@ -173,18 +182,32 @@ class Logger:
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
optimizer: Optimizer | dict,
scheduler: LRScheduler | None,
interaction_step: int | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
# In SAC, for example, we have a dictionary of torch.optim.Optimizer instances
if type(optimizer) is dict:
optimizer_state_dict = {}
for k in optimizer:
optimizer_state_dict[k] = optimizer[k].state_dict()
else:
optimizer_state_dict = optimizer.state_dict()
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
"optimizer": optimizer_state_dict,
**get_global_random_state(),
}
# The interaction step is specific to the distributed training setup, which has two
# kinds of steps: the online step in the environment and the optimization step.
# We save both so that optimization resumes properly and the logs that depend on the
# interaction step stay consistent.
if interaction_step is not None:
training_state["interaction_step"] = interaction_step
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
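A minimal sketch of the round-trip this enables for SAC-style policies with one optimizer per network (the "actor"/"critic" names are assumptions):

```python
import torch

actor = torch.nn.Linear(4, 2)
critic = torch.nn.Linear(4, 1)
optimizer = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
}

# Save: one state_dict per named optimizer, mirroring save_training_state above.
state = {k: opt.state_dict() for k, opt in optimizer.items()}
torch.save({"step": 0, "optimizer": state}, "/tmp/training_state.pth")

# Load: the key sets must match, as asserted in load_last_training_state below.
loaded = torch.load("/tmp/training_state.pth")
for k, v in loaded["optimizer"].items():
    optimizer[k].load_state_dict(v)
```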
@@ -196,6 +219,7 @@ class Logger:
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
interaction_step: int | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -205,18 +229,34 @@ class Logger:
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
checkpoint_dir / self.pretrained_model_dir_name,
policy,
wandb_artifact_name=wandb_artifact_name,
)
self.save_training_state(
checkpoint_dir, train_step, optimizer, scheduler, interaction_step
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
def load_last_training_state(
self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
training_state = torch.load(
self.last_checkpoint_dir / self.training_state_file_name
)
# For the case where the optimizer is a dictionary of optimizers (e.g., SAC)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(
optimizer.keys()
), "Optimizer dictionaries do not have the same keys during resume!"
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
@@ -224,20 +264,63 @@ class Logger:
"The checkpoint contains a scheduler state_dict, but no LRScheduler was provided."
)
# Small hack to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
set_global_random_state(
{k: training_state[k] for k in get_global_random_state()}
)
return training_state["step"]
def log_dict(self, d, step, mode="train"):
def log_dict(
self,
d,
step: int | None = None,
mode="train",
custom_step_key: str | None = None,
):
"""Log a dictionary of metrics to WandB."""
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
if self._wandb is not None:
# NOTE: This is not simple. The wandb step must increase monotonically, and it
# increases with every wandb.log call. In asynchronous RL, however, several kinds of
# steps coexist (e.g. the interaction step with the environment, the training step,
# the evaluation step), so we need to define a custom step key to log the correct
# step for each metric.
if custom_step_key is not None:
if self._wandb_custom_step_key is None:
self._wandb_custom_step_key = set()
new_custom_key = f"{mode}/{custom_step_key}"
if new_custom_key not in self._wandb_custom_step_key:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
if not isinstance(v, (int, float, str, wandb.Table)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
continue
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
self._wandb.log(
{
f"{mode}/{k}": v,
f"{mode}/{custom_step_key}": value_custom_step,
}
)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
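To make the custom-step mechanism concrete, here is a hedged sketch of the underlying wandb calls: the metric and its own step key are logged together, so out-of-order arrivals from asynchronous producers never violate wandb's monotonic global step.

```python
import wandb

run = wandb.init(project="demo", mode="offline")

# Register a hidden step axis once, as log_dict does on first use of a custom key.
wandb.define_metric("train/interaction_step", hidden=True)

# Each metric is logged alongside its own step value instead of the global step,
# so an actor and a learner reporting at different rates cannot clash.
wandb.log({"train/reward": 1.0, "train/interaction_step": 120})
wandb.log({"train/reward": 1.5, "train/interaction_step": 80})  # still fine
```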

View File

@@ -168,4 +168,6 @@ class ACTConfig:
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)

View File

@@ -81,10 +81,14 @@ class ACTPolicy(
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.temporal_ensembler = ACTTemporalEnsembler(
config.temporal_ensemble_coeff, config.chunk_size
)
self.reset()
@@ -107,8 +111,12 @@ class ACTPolicy(
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -135,13 +143,18 @@ class ACTPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch["action"], actions_hat, reduction="none")
* ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}
@@ -151,7 +164,12 @@ class ACTPolicy(
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
(
-0.5
* (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())
)
.sum(-1)
.mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
@@ -205,7 +223,9 @@ class ACTTemporalEnsembler:
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights = torch.exp(
-temporal_ensemble_coeff * torch.arange(chunk_size)
)
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
@@ -221,7 +241,9 @@ class ACTTemporalEnsembler:
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(
device=actions.device
)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
@@ -229,19 +251,34 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
self.ensembled_actions *= self.ensemble_weights_cumsum[
self.ensembled_actions_count - 1
]
self.ensembled_actions += (
actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
)
self.ensembled_actions /= self.ensemble_weights_cumsum[
self.ensembled_actions_count
]
self.ensembled_actions_count = torch.clamp(
self.ensembled_actions_count + 1, max=self.chunk_size
)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions = torch.cat(
[self.ensembled_actions, actions[:, -1:]], dim=1
)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
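The three in-place tensor updates above implement an online exponentially weighted average with weights w_i = exp(-coeff * i). A scalar sketch of the same recurrence:

```python
import torch

coeff, chunk_size = 0.01, 4
w = torch.exp(-coeff * torch.arange(chunk_size))  # ensemble_weights
w_cumsum = torch.cumsum(w, dim=0)                 # ensemble_weights_cumsum

avg = torch.tensor(1.0)  # weighted average after seeing one prediction
count = 1                # ensembled_actions_count for this time slot
for new_value in (2.0, 3.0):
    # Same algebra as above: un-normalize, add the new weighted term, re-normalize.
    avg = (avg * w_cumsum[count - 1] + new_value * w[count]) / w_cumsum[count]
    count += 1

# avg == (w[0]*1 + w[1]*2 + w[2]*3) / (w[0] + w[1] + w[2])
```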
@@ -293,7 +330,9 @@ class ACT(nn.Module):
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_images = any(
k.startswith("observation.image") for k in config.input_shapes
)
self.use_env_state = "observation.environment_state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
@@ -308,7 +347,9 @@ class ACT(nn.Module):
config.output_shapes["action"][0], config.dim_model
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(
config.dim_model, config.latent_dim * 2
)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
@@ -316,20 +357,28 @@ class ACT(nn.Module):
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(
num_input_token_encoder, config.dim_model
).unsqueeze(0),
)
# Backbone for image feature extraction.
if self.use_images:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
# feature map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
self.backbone = IntermediateLayerGetter(
backbone_model, return_layers={"layer4": "feature_map"}
)
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = ACTEncoder(config)
@@ -343,7 +392,8 @@ class ACT(nn.Module):
)
if self.use_env_state:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
config.input_shapes["observation.environment_state"][0],
config.dim_model,
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
@@ -358,14 +408,18 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self.action_head = nn.Linear(
config.dim_model, config.output_shapes["action"][0]
)
self._reset_parameters()
@@ -375,7 +429,9 @@ class ACT(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
def forward(
self, batch: dict[str, Tensor]
) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
@@ -412,12 +468,20 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.use_robot_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -455,20 +519,26 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros(
[batch_size, self.config.latent_dim], dtype=torch.float32
).to(batch["observation.state"].device)
# Prepare transformer encoder inputs.
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(
self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)
)
# Robot state token.
if self.use_robot_state:
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
# Environment state token.
if self.use_env_state:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(
batch["observation.environment_state"]
)
)
# Camera observation features and positional embeddings.
@@ -477,19 +547,29 @@ class ACT(nn.Module):
all_cam_pos_embeds = []
for cam_index in range(batch["observation.images"].shape[-4]):
cam_features = self.backbone(batch["observation.images"][:, cam_index])[
"feature_map"
]
# TODO(rcadene, alexander-soare): remove call to `.to` to speed up the forward pass; precompute
# and use a buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(
cam_features
) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
# and move to (sequence, batch, dim).
all_cam_features = torch.cat(all_cam_features, axis=-1)
encoder_in_tokens.extend(
einops.rearrange(all_cam_features, "b c h w -> (h w) b c")
)
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
encoder_in_pos_embed.extend(
einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")
)
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
@@ -524,12 +604,21 @@ class ACTEncoder(nn.Module):
def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):
super().__init__()
self.is_vae_encoder = is_vae_encoder
num_layers = (
config.n_vae_encoder_layers
if self.is_vae_encoder
else config.n_encoder_layers
)
self.layers = nn.ModuleList(
[ACTEncoderLayer(config) for _ in range(num_layers)]
)
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -540,7 +629,9 @@ class ACTEncoder(nn.Module):
class ACTEncoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -555,7 +646,9 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(
self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
@@ -580,7 +673,9 @@ class ACTDecoder(nn.Module):
def __init__(self, config: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList(
[ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]
)
self.norm = nn.LayerNorm(config.dim_model)
def forward(
@@ -592,7 +687,10 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
)
if self.norm is not None:
x = self.norm(x)
@@ -602,8 +700,12 @@ class ACTDecoder(nn.Module):
class ACTDecoderLayer(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
self.multihead_attn = nn.MultiheadAttention(
config.dim_model, config.n_heads, dropout=config.dropout
)
# Feed forward layers.
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)
@@ -644,7 +746,9 @@ class ACTDecoderLayer(nn.Module):
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[
0
] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -681,9 +785,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / dimension)
for hid_j in range(dimension)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(num_positions)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
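For intuition, a minimal check of the table built above (an illustrative snippet, not part of the diff): position 0 has all-zero angles, so its sine (even) dimensions are 0 and its cosine (odd) dimensions are 1.

table = create_sinusoidal_pos_embedding(num_positions=4, dimension=6)
print(table.shape)  # torch.Size([4, 6])
print(table[0])     # tensor([0., 1., 0., 1., 0., 1.])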
@@ -728,7 +837,9 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2
* (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2)
/ self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
@@ -736,9 +847,15 @@ class ACTSinusoidalPositionEmbedding2d(nn.Module):
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack(
(x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed_y = torch.stack(
(y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1
).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(
0, 3, 1, 2
) # (1, C, H, W)
return pos_embed


@@ -121,7 +121,9 @@ class DiffusionConfig:
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
# Architecture / modeling.
# Vision backbone.
@@ -163,8 +165,13 @@ class DiffusionConfig:
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if (
len(image_keys) == 0
and "observation.environment_state" not in self.input_shapes
):
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
if len(image_keys) > 0:
if self.crop_shape is not None:


@@ -88,7 +88,9 @@ class DiffusionPolicy(
self.diffusion = DiffusionModel(config)
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.use_env_state = "observation.environment_state" in config.input_shapes
self.reset()
@@ -102,7 +104,9 @@ class DiffusionPolicy(
if len(self.expected_image_keys) > 0:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
self._queues["observation.environment_state"] = deque(
maxlen=self.config.n_obs_steps
)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -128,14 +132,22 @@ class DiffusionPolicy(
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
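As a hedged illustration of the observation-queue pattern above (assuming n_obs_steps = 2; `populate_queues` maintains per-key deques along these lines):

from collections import deque
import torch

queue = deque(maxlen=2)  # stands in for self._queues["observation.state"]
for step in range(3):
    queue.append(torch.full((1, 4), float(step)))  # newest observation wins
stacked = torch.stack(list(queue), dim=1)  # (batch, n_obs_steps, dim) = (1, 2, 4)
print(stacked[0, :, 0])  # tensor([1., 2.]): only the two latest steps remain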
@@ -150,8 +162,12 @@ class DiffusionPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@@ -177,7 +193,9 @@ class DiffusionModel(nn.Module):
# Build observation encoders (depending on which observations are provided).
global_cond_dim = config.input_shapes["observation.state"][0]
num_images = len(
[k for k in config.input_shapes if k.startswith("observation.image")]
)
self._use_images = False
self._use_env_state = False
if num_images > 0:
@@ -193,7 +211,9 @@ class DiffusionModel(nn.Module):
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
self.unet = DiffusionConditionalUnet1d(
config, global_cond_dim=global_cond_dim * config.n_obs_steps
)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
@@ -213,14 +233,21 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(
batch_size,
self.config.horizon,
self.config.output_shapes["action"][0],
),
dtype=dtype,
device=device,
generator=generator,
@@ -236,7 +263,9 @@ class DiffusionModel(nn.Module):
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(
model_output, t, sample, generator=generator
).prev_sample
return sample
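For reference, a standalone sketch of the same denoising loop using a diffusers DDPMScheduler (the scheduler type is an assumption for illustration; the actual one comes from config.noise_scheduler_type), with a zero tensor standing in for the UNet prediction:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=100)
scheduler.set_timesteps(100)
sample = torch.randn(1, 16, 2)  # (batch, horizon, action_dim)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for self.unet(...)
    sample = scheduler.step(model_output, t, sample).prev_sample  # x_t -> x_{t-1}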
@@ -248,27 +277,39 @@ class DiffusionModel(nn.Module):
if self._use_images:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(
batch["observation.images"], "b s n ... -> n (b s) ..."
)
img_features_list = torch.cat(
[
encoder(images)
for encoder, images in zip(
self.rgb_encoder, images_per_camera, strict=True
)
]
)
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features_list,
"(n b s) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
einops.rearrange(
batch["observation.images"], "b s n ... -> (b s n) ..."
)
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
img_features,
"(b s n) ... -> b s (n ...)",
b=batch_size,
s=n_obs_steps,
)
global_cond_feats.append(img_features)
@@ -354,7 +395,9 @@ class DiffusionModel(nn.Module):
elif self.config.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
raise ValueError(
f"Unsupported prediction type {self.config.prediction_type}"
)
loss = F.mse_loss(pred, target, reduction="none")
@@ -414,7 +457,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
@@ -456,7 +501,9 @@ class DiffusionRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -477,7 +524,9 @@ class DiffusionRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@@ -485,17 +534,25 @@ class DiffusionRgbEncoder(nn.Module):
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape
if config.crop_shape is not None
else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
)
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -522,7 +579,9 @@ class DiffusionRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@@ -535,7 +594,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -550,7 +613,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
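A hypothetical usage sketch of `_replace_submodules`, mirroring the BatchNorm-to-GroupNorm swap performed in DiffusionRgbEncoder above (the resnet18 backbone is an assumption for illustration):

import torch.nn as nn
import torchvision

backbone = torchvision.models.resnet18()
backbone = _replace_submodules(
    root_module=backbone,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16, num_channels=m.num_features),
)
assert not any(isinstance(m, nn.BatchNorm2d) for m in backbone.modules())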
@@ -578,7 +643,9 @@ class DiffusionConv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
@@ -601,9 +668,13 @@ class DiffusionConditionalUnet1d(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
nn.Linear(
config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4
),
nn.Mish(),
nn.Linear(
config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim
),
)
# The FiLM conditioning dimension.
@@ -628,10 +699,16 @@ class DiffusionConditionalUnet1d(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
dim_in, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
@@ -640,10 +717,14 @@ class DiffusionConditionalUnet1d(nn.Module):
self.mid_modules = nn.ModuleList(
[
DiffusionConditionalResidualBlock1d(
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
DiffusionConditionalResidualBlock1d(
config.down_dims[-1],
config.down_dims[-1],
**common_res_block_kwargs,
),
]
)
@@ -656,16 +737,24 @@ class DiffusionConditionalUnet1d(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
DiffusionConditionalResidualBlock1d(
dim_in * 2, dim_out, **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
dim_out, dim_out, **common_res_block_kwargs
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1)
if not is_last
else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
),
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
)
@@ -733,17 +822,23 @@ class DiffusionConditionalResidualBlock1d(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = DiffusionConv1dBlock(
in_channels, out_channels, kernel_size, n_groups=n_groups
)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = DiffusionConv1dBlock(
out_channels, out_channels, kernel_size, n_groups=n_groups
)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:


@@ -52,7 +52,9 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
return TDMPCPolicy, TDMPCConfig
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import (
DiffusionConfig,
)
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy, DiffusionConfig
@@ -66,12 +68,21 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig
elif name == "sac":
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy(
hydra_cfg: DictConfig,
pretrained_policy_name_or_path: str | None = None,
dataset_stats=None,
*args,
**kwargs,
) -> Policy:
"""Make an instance of a policy class.
@@ -85,17 +96,19 @@ def make_policy(
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
# raise ValueError(
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
# )
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
# Make a fresh policy.
# HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments,
# for example the device for the SAC policy.
policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
else:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
@@ -104,7 +117,9 @@ def make_policy(
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(
policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()
)
policy.to(get_safe_torch_device(hydra_cfg.device))


@@ -0,0 +1,35 @@
import json
import os
from dataclasses import asdict, dataclass
@dataclass
class ClassifierConfig:
"""Configuration for the Classifier model."""
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2
def save_pretrained(self, save_dir):
"""Save config to json file."""
os.makedirs(save_dir, exist_ok=True)
# Convert to dict and save as JSON
config_dict = asdict(self)
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path):
"""Load config from json file."""
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
with open(config_file) as f:
config_dict = json.load(f)
return cls(**config_dict)
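A quick round-trip sketch for this config (the paths are illustrative):

cfg = ClassifierConfig(num_classes=2, num_cameras=2)
cfg.save_pretrained("outputs/classifier")  # writes outputs/classifier/config.json
restored = ClassifierConfig.from_pretrained("outputs/classifier")
assert restored == cfg  # dataclasses compare field by field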


@@ -0,0 +1,173 @@
import logging
from typing import Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from .configuration_classifier import ClassifierConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self,
logits: Tensor,
probabilities: Optional[Tensor] = None,
hidden_states: Optional[Tensor] = None,
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
def __repr__(self):
return (
f"ClassifierOutput(logits={self.logits}, "
f"probabilities={self.probabilities}, "
f"hidden_states={self.hidden_states})"
)
class Classifier(
nn.Module,
PyTorchModelHubMixin,
# Add Hub metadata
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vision-classifier"],
):
"""Image classifier built on top of a pre-trained encoder."""
# Add name attribute for factory
name = "classifier"
def __init__(self, config: ClassifierConfig):
from transformers import AutoModel
super().__init__()
self.config = config
# self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(
self.config.model_name, trust_remote_code=True
)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[
-1
] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
self.encoder = self.encoder.to(self.config.device)
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.feature_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError(
"Unsupported transformer architecture since hidden_size is not found"
)
self.classifier_head = nn.Sequential(
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(
self.config.hidden_dim,
1 if self.config.num_classes == 2 else self.config.num_classes,
),
)
self.classifier_head = self.classifier_head.to(self.config.device)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization)
# processed = self.processor(
# images=x, # LeRobotDataset already provides proper tensor format
# return_tensors="pt",
# )
# processed = processed["pixel_values"].to(x.device)
processed = x
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoder(processed)
# Get pooled output directly
features = outputs.pooler_output
if features.dim() > 2:
features = features.squeeze(-1).squeeze(-1)
return features
else: # Transformer models
outputs = self.encoder(processed)
if (
hasattr(outputs, "pooler_output")
and outputs.pooler_output is not None
):
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
logits = self.classifier_head(encoder_outputs)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(
logits=logits, probabilities=probabilities, hidden_states=encoder_outputs
)
def predict_reward(self, x, threshold=0.6):
if self.config.num_classes == 2:
probs = self.forward(x).probabilities
logging.debug(f"Predicted reward images: {probs}")
return (probs > threshold).float()
else:
return torch.argmax(self.forward(x).probabilities, dim=1)
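A hedged end-to-end sketch of the classifier (this assumes the default "helper2424/resnet10" encoder is downloadable and accepts 3x224x224 frames; both are assumptions, not guarantees from the source):

import torch

config = ClassifierConfig(num_cameras=2)
classifier = Classifier(config)
xs = torch.randn(2, 8, 3, 224, 224)  # (num_cameras, batch, C, H, W); features are hstacked per camera
out = classifier(xs)
print(out.logits.shape, out.probabilities.shape)  # torch.Size([8]) torch.Size([8]) for num_classes=2
print(classifier.predict_reward(xs, threshold=0.6))  # binary rewards in {0., 1.}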


@@ -1,6 +1,7 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class HILSerlConfig:
pass


@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass


@@ -130,7 +130,7 @@ class Normalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
@@ -196,7 +196,7 @@ class Unnormalize(nn.Module):
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():


@@ -0,0 +1,108 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any
@dataclass
class SACConfig:
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
input_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"observation.image": {
"mean": [[0.485, 0.456, 0.406]],
"std": [[0.229, 0.224, 0.225]],
},
"observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
output_normalization_params: dict[str, dict[str, list[float]]] = field(
default_factory=lambda: {
"action": {"min": [-1, -1], "max": [1, 1]},
}
)
# TODO: Move it outside of the config
actor_learner_config: dict[str, str | int] = field(
default_factory=lambda: {
"learner_host": "127.0.0.1",
"learner_port": 50051,
}
)
camera_number: int = 1
storage_device: str = "cpu"
vision_encoder_name: str | None = field(default="helper2424/resnet10")
freeze_vision_encoder: bool = True
image_encoder_hidden_dim: int = 32
shared_encoder: bool = True
discount: float = 0.99
temperature_init: float = 1.0
num_critics: int = 2
num_subsample_critics: int | None = None
critic_lr: float = 3e-4
actor_lr: float = 3e-4
temperature_lr: float = 3e-4
critic_target_update_weight: float = 0.005
utd_ratio: int = 1 # Set to a value > 1 to enable a higher update-to-data (UTD) ratio
state_encoder_hidden_dim: int = 256
latent_dim: int = 256
target_entropy: float | None = None
use_backup_entropy: bool = True
grad_clip_norm: float = 40.0
critic_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
"final_activation": None,
}
)
actor_network_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"hidden_dims": [256, 256],
"activate_final": True,
}
)
policy_kwargs: dict[str, Any] = field(
default_factory=lambda: {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
"init_final": 0.05,
}
)
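Since every field has a default, overrides stay small. A sketch with a hypothetical 4-D action space (values are illustrative only):

config = SACConfig(
    num_critics=2,
    storage_device="cuda",
    output_shapes={"action": [4]},
    output_normalization_params={"action": {"min": [-1] * 4, "max": [1] * 4}},
)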


@@ -0,0 +1,981 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: (1) better device management
import math
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
name = "sac"
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
)
self.normalize_inputs = Normalize(
config.input_shapes,
config.input_normalization_modes,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# NOTE: For images the encoder should be shared between the actor and critic
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor: SACObservationEncoder = encoder_critic
else:
encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=encoder_critic,
ensemble=critic_heads,
output_normalization=self.normalize_targets,
)
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=encoder_critic,
ensemble=target_critic_heads,
output_normalization=self.normalize_targets,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature is fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
temperature_init = config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import os
import json
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import os
import json
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file) as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
load_model(model, filename=safetensors_file, device=map_location)
return model
def reset(self):
"""Reset the policy"""
pass
def to(self, *args, **kwargs):
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
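The same soft (Polyak) update written out for a single parameter pair, with tau = config.critic_target_update_weight (0.005 by default in SACConfig):

import torch

tau = 0.005
theta, theta_target = torch.tensor([1.0]), torch.tensor([0.0])
theta_target = tau * theta + (1 - tau) * theta_target  # tensor([0.0050])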
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
# TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do it
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
"action"
]
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(
observations,
actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean loss over the batch for each critic, then sum across critics for the final loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
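A numeric sketch of the TD target computed above, for one transition (all values hypothetical; discount=0.99, temperature=0.1):

reward, done, min_q, next_log_prob = 1.0, 0.0, 5.0, -2.0
min_q_backup = min_q - 0.1 * next_log_prob  # entropy backup: 5.2
td_target = reward + (1 - done) * 0.99 * min_q_backup  # 1 + 0.99 * 5.2 = 6.148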
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
return temperature_loss
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
# TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do it
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
q_preds = self.critic_forward(
observations,
actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.activate_final = activate_final
layers = []
# First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
else getattr(nn, final_activation)()
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
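A sketch instantiating this MLP with the defaults from SACConfig.actor_network_kwargs (the input_dim of 256 matches the encoder's latent_dim):

import torch

mlp = MLP(input_dim=256, hidden_dims=[256, 256], activate_final=True)
x = torch.randn(4, 256)
print(mlp(x).shape)  # torch.Size([4, 256])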
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
┌──────────────────┬─────────────────────────────────────────────────────────┐
│ Critic Ensemble │ │
├──────────────────┘ │
│ │
│ ┌────┐ ┌────┐ ┌────┐ │
│ │ Q1 │ │ Q2 │ │ Qn │ │
│ └────┘ └────┘ └────┘ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ │ │ │ │ │ │
│ │ MLP 1 │ │ MLP 2 │ │ MLP │ │
│ │ │ │ │ ... │ num_critics │ │
│ │ │ │ │ │ │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ ▲ ▲ ▲ │
│ └───────────────────┴───────┬────────────────────────────┘ │
│ │ │
│ │ │
│ ┌───────────────────┐ │
│ │ Embedding │ │
│ │ │ │
│ └───────────────────┘ │
│ ▲ │
│ │ │
│ ┌─────────────┴────────────┐ │
│ │ │ │
│ │ SACObservationEncoder │ │
│ │ │ │
│ └──────────────────────────┘ │
│ ▲ │
│ │ │
│ │ │
│ │ │
└───────────────────────────┬────────────────────┬───────────────────────────┘
│ Observation │
└────────────────────┘
"""
def __init__(
self,
encoder: Optional[nn.Module],
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
self.parameters_to_optimize = []
# Handle the case where part of the encoder is frozen
if self.encoder is not None:
self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
self.parameters_to_optimize += list(self.critics.parameters())
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions; it helps sample efficiency
actions: dict[str, torch.Tensor] = {"action": actions}
# NOTE: The normalization layer takes a dict as input and returns a dict, hence the wrapping/unwrapping
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
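A shape sketch for the ensemble (dimensions are illustrative; features are passed in directly, so no encoder is needed):

import torch
from torch import nn

heads = [CriticHead(input_dim=10, hidden_dims=[32, 32], activate_final=True) for _ in range(2)]
ensemble = CriticEnsemble(encoder=None, ensemble=heads, output_normalization=nn.Identity())
features = torch.randn(8, 8)  # stand-in for precomputed observation features
actions = torch.rand(8, 2)
q_values = ensemble({}, actions, observation_features=features)
print(q_values.shape)  # torch.Size([2, 8]): (num_critics, batch_size)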
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
self.parameters_to_optimize += list(self.mean_layer.parameters())
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
def forward(
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = (
observation_features
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(
log_std
).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# Sample with the reparameterization trick; tanh (if enabled) squashes the action into [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
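The tanh change-of-variables correction used above, in isolation: log pi(a) = log N(x) - sum_i log(1 - tanh(x_i)^2 + eps).

import torch

normal = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
x = normal.rsample()  # reparameterized sample
action = torch.tanh(x)  # squash into [-1, 1]
log_prob = normal.log_prob(x) - torch.log(1 - action.pow(2) + 1e-6)
log_prob = log_prob.sum(-1)  # scalar log-probability over action dims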
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
device = get_device_from_parameters(self)
observations = observations.to(device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self.has_pretrained_vision_encoder = False
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
self.image_enc_layers = PretrainedImageEncoder(config)
self.has_pretrained_vision_encoder = True
else:
self.image_enc_layers = DefaultImageEncoder(config)
self.aggregation_size += config.latent_dim * self.camera_number
if config.freeze_vision_encoder:
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.environment_state"][0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
nn.Tanh(),
)
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
        Each modality is encoded into a feature vector of size (latent_dim,); the vectors are then
        concatenated and projected back to size (latent_dim,) by a linear aggregation layer.
"""
feat = []
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
features = self.aggregation_layer(features)
return features
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
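        # The dummy forward pass above infers the conv stack's output shape, so the
        # flatten/linear head below can be sized without hand-computing strides.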
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
def forward(self, x):
return self.image_enc_layers(x)
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
# TODO: (maractingi, azouitine) check the forward pass of the pretrained model
# doesn't reach the classifier layer because we don't need it
enc_feat = self.image_enc_layers(x).pooler_output
enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
return enc_feat
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
return converted_params
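# Illustrative example (values made up, not from any shipped config):
#   {"observation.image": {"mean": [0.5, 0.5, 0.5]}} becomes a dict of tensors where
#   the image stats are reshaped to (3, 1, 1), so per-channel statistics broadcast
#   over (C, H, W) image tensors.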
if __name__ == "__main__":
# Benchmark the CriticEnsemble performance
import time
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
# print("\nVerifying critic outputs are different:")
# for i in range(num_critics):
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")


@@ -191,6 +191,10 @@ class TDMPCConfig:
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
)
if not self.use_mpc:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
raise ValueError(
"If `n_action_steps > 1`, `use_mpc` must be set to `True`."
)
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
raise ValueError(
"`n_action_steps` must be less than or equal to `horizon`."
)


@@ -68,7 +68,9 @@ class TDMPCPolicy(
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
self,
config: TDMPCConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
@@ -100,7 +102,9 @@ class TDMPCPolicy(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
        # Note: This check is covered in the post-init of the config, but we keep a sanity check here just in case.
self._use_image = False
self._use_env_state = False
@@ -120,7 +124,9 @@ class TDMPCPolicy(
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
"action": deque(
maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)
),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
@@ -135,7 +141,9 @@ class TDMPCPolicy(
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
self._queues = populate_queues(self._queues, batch)
@@ -209,13 +217,20 @@ class TDMPCPolicy(
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
# trajectories.
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
z = einops.repeat(
z,
"b d -> n b d",
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
)
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
self.config.horizon,
batch_size,
self.config.output_shapes["action"][0],
device=device,
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -231,35 +246,47 @@ class TDMPCPolicy(
self.config.output_shapes["action"][0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
gaussian_actions = torch.clamp(
mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1
)
# Compute elite actions.
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
value = self.estimate_value(z, actions).nan_to_num_(0)
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
elite_idxs = torch.topk(
value, self.config.n_elites, dim=0
).indices # (n_elites, batch)
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
# (horizon, n_elites, batch, action_dim)
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
elite_actions = actions.take_along_dim(
einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1
)
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
score = torch.exp(
self.config.elite_weighting_temperature * (elite_value - max_value)
)
score /= score.sum(axis=0, keepdim=True)
# (horizon, batch, action_dim)
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
_mean = torch.sum(
einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1
)
_std = torch.sqrt(
torch.sum(
einops.rearrange(score, "n b -> n b 1")
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d"))
** 2,
dim=1,
)
)
# Update mean with an exponential moving average, and std with a direct replacement.
mean = (
self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean
self.config.gaussian_mean_momentum * mean
+ (1 - self.config.gaussian_mean_momentum) * _mean
)
std = _std.clamp_(self.config.min_std, self.config.max_std)
@@ -268,7 +295,9 @@ class TDMPCPolicy(
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
# scores from the last iteration.
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
actions = elite_actions[
:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)
]
return actions
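
For readers skimming these reformat-only hunks, here is a minimal, self-contained sketch of the CEM update they touch: sample around the running mean/std, keep the top-k elites, and recompute the mean/std as a softmax-weighted average over elite values. Names, shapes, and defaults are illustrative, not the repository's code.

    import torch

    def cem_step(mean, std, estimate_value, n_samples=256, n_elites=32, temperature=0.5):
        # mean/std: (horizon, action_dim); estimate_value: (n, horizon, action_dim) -> (n,)
        noise = torch.randn(n_samples, *mean.shape, device=mean.device)
        actions = torch.clamp(mean + std * noise, -1, 1)
        values = estimate_value(actions)
        elite_idxs = torch.topk(values, n_elites, dim=0).indices
        elite_actions, elite_values = actions[elite_idxs], values[elite_idxs]
        # Softmax weighting over elite values: the normalized s = Omega / sum(Omega) from the comments above.
        score = torch.softmax(temperature * (elite_values - elite_values.max()), dim=0)
        w = score.view(n_elites, 1, 1)
        new_mean = (w * elite_actions).sum(dim=0)
        new_std = torch.sqrt((w * (elite_actions - new_mean) ** 2).sum(dim=0))
        return new_mean, new_std
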
@@ -291,7 +320,8 @@ class TDMPCPolicy(
# of the FOWM paper.
if self.config.uncertainty_regularizer_coeff > 0:
regularization = -(
self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0)
self.config.uncertainty_regularizer_coeff
* self.model.Qs(z, actions[t]).std(0)
)
else:
regularization = 0
@@ -311,15 +341,22 @@ class TDMPCPolicy(
if self.config.q_ensemble_size > 2:
G += (
running_discount
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
0
]
* torch.min(
terminal_values[
torch.randint(0, self.config.q_ensemble_size, size=(2,))
],
dim=0,
)[0]
)
else:
G += running_discount * torch.min(terminal_values, dim=0)[0]
# Finally, also regularize the terminal value.
if self.config.uncertainty_regularizer_coeff > 0:
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
G -= (
running_discount
* self.config.uncertainty_regularizer_coeff
* terminal_values.std(0)
)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
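
The terminal-value computation in the hunk above subsamples the Q-ensemble, REDQ-style: when the ensemble is large enough, it takes the min over a random pair of critics rather than over the full ensemble. A sketch of that reduction alone (illustrative):

    import torch

    def subsampled_min_q(q_values: torch.Tensor, subset_size: int = 2) -> torch.Tensor:
        # q_values: (ensemble, batch); pessimistic reduction over a random subset.
        idx = torch.randint(0, q_values.shape[0], size=(subset_size,))
        return q_values[idx].min(dim=0).values
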
@@ -331,7 +368,9 @@ class TDMPCPolicy(
batch = self.normalize_inputs(batch)
if self._use_image:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch = self.normalize_targets(batch)
@@ -349,7 +388,10 @@ class TDMPCPolicy(
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
partial(
random_shifts_aug,
max_random_shift_ratio=self.config.max_random_shift_ratio,
),
observations["observation.image"],
)
@@ -367,14 +409,20 @@ class TDMPCPolicy(
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
# gives us a next `z`.
batch_size = batch["index"].shape[0]
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
z_preds = torch.empty(
horizon + 1, batch_size, self.config.latent_dim, device=device
)
z_preds[0] = self.model.encode(current_observation)
reward_preds = torch.empty_like(reward, device=device)
for t in range(horizon):
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(
z_preds[t], action[t]
)
# Compute Q and V value predictions based on the latent rollout.
q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch)
q_preds_ensemble = self.model.Qs(
z_preds[:-1], action
) # (ensemble, horizon, batch)
v_preds = self.model.V(z_preds[:-1])
info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()})
@@ -388,10 +436,14 @@ class TDMPCPolicy(
# actions (not actions estimated by π).
# Note: Here we do not use self.model_target, but self.model. This is to follow the original code
# and the FOWM paper.
q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations))
q_targets = reward + self.config.discount * self.model.V(
self.model.encode(next_observations)
)
# From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we
# are using them to compute loss for V.
v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True)
v_targets = self.model_target.Qs(
z_preds[:-1].detach(), action, return_min=True
)
# Compute losses.
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
@@ -434,7 +486,9 @@ class TDMPCPolicy(
temporal_loss_coeffs
* F.mse_loss(
q_preds_ensemble,
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
einops.repeat(
q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]
),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
@@ -472,12 +526,14 @@ class TDMPCPolicy(
z_preds = z_preds.detach()
# Use stopgrad for the advantage calculation.
with torch.no_grad():
advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V(
z_preds[:-1]
)
advantage = self.model_target.Qs(
z_preds[:-1], action, return_min=True
) - self.model.V(z_preds[:-1])
info["advantage"] = advantage[0]
# (t, b)
exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0)
exp_advantage = torch.clamp(
torch.exp(advantage * self.config.advantage_scaling), max=100.0
)
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
# Calculate the MSE between the actions and the action predictions.
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
@@ -532,7 +588,9 @@ class TDMPCPolicy(
# Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA
# update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code
# we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995)
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
update_ema_parameters(
self.model_target, self.model, self.config.target_model_momentum
)
class TDMPCTOLD(nn.Module):
@@ -543,7 +601,9 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -554,7 +614,9 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.output_shapes["action"][0], config.mlp_dim
),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -574,7 +636,10 @@ class TDMPCTOLD(nn.Module):
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(
config.latent_dim + config.output_shapes["action"][0],
config.mlp_dim,
),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -619,7 +684,9 @@ class TDMPCTOLD(nn.Module):
m[-1], nn.Linear
), "Sanity check. The last linear layer needs 0 initialization on weights."
nn.init.zeros_(m[-1].weight)
nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure
nn.init.zeros_(
m[-1].bias
) # this has already been done, but keep this line here for good measure
def encode(self, obs: dict[str, Tensor]) -> Tensor:
"""Encodes an observation into its latent representation."""
@@ -717,14 +784,32 @@ class TDMPCObservationEncoder(nn.Module):
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
config.input_shapes["observation.image"][0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.Conv2d(
config.image_encoder_hidden_dim,
config.image_encoder_hidden_dim,
3,
stride=2,
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
@@ -740,7 +825,10 @@ class TDMPCObservationEncoder(nn.Module):
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.Linear(
config.input_shapes["observation.state"][0],
config.state_encoder_hidden_dim,
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -749,7 +837,8 @@ class TDMPCObservationEncoder(nn.Module):
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
config.input_shapes["observation.environment_state"][0],
config.state_encoder_hidden_dim,
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
@@ -766,9 +855,15 @@ class TDMPCObservationEncoder(nn.Module):
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict["observation.image"]
)
)
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
@@ -811,12 +906,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
for (n_p_ema, p_ema), (n_p, p) in zip(
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
ema_module.named_parameters(recurse=False),
module.named_parameters(recurse=False),
strict=True,
):
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
if isinstance(p, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad:
if (
isinstance(module, nn.modules.batchnorm._BatchNorm)
or not p.requires_grad
):
# Copy BatchNorm parameters, and non-trainable parameters directly.
p_ema.copy_(p.to(dtype=p_ema.dtype).data)
with torch.no_grad():
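
Stated on its own, the EMA rule these hunks wrap is just the in-place update below (a sketch under the same alpha convention, not the repository's code):

    import torch

    @torch.no_grad()
    def ema_update_(p_ema: torch.Tensor, p: torch.Tensor, alpha: float) -> None:
        # In place: p_ema <- alpha * p_ema + (1 - alpha) * p
        p_ema.mul_(alpha).add_(p.to(dtype=p_ema.dtype), alpha=1 - alpha)
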
@@ -824,7 +924,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha)
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
def flatten_forward_unflatten(
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:


@@ -109,7 +109,9 @@ class VQBeTConfig:
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"}
)
# Architecture / modeling.
# Vision backbone.


@@ -79,7 +79,9 @@ class VQBeTPolicy(
self.vqbet = VQBeTModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.expected_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.reset()
@@ -104,8 +106,12 @@ class VQBeTPolicy(
"""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -116,8 +122,14 @@ class VQBeTPolicy(
)
if len(self._queues["action"]) == 0:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
batch = {
k: torch.stack(list(self._queues[k]), dim=1)
for k in batch
if k in self._queues
}
actions = self.vqbet(batch, rollout=True)[
:, : self.config.action_chunk_size
]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
@@ -130,8 +142,12 @@ class VQBeTPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[k] for k in self.expected_image_keys], dim=-4
)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -139,7 +155,9 @@ class VQBeTPolicy(
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
self.vqbet.action_head.discretize(
self.config.n_vqvae_training_steps, batch["action"]
)
)
return {
"loss": loss,
@@ -196,7 +214,9 @@ class SpatialSoftmax(nn.Module):
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x, pos_y = np.meshgrid(
np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)
)
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
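
For context, the pos_x/pos_y grids built above feed a spatial-softmax expectation over the feature map; a minimal sketch of that operation (illustrative, omitting any keypoint projection the module may apply):

    import torch
    import torch.nn.functional as F

    def spatial_softmax(feat, pos_x, pos_y):
        # feat: (B, C, H, W); pos_x/pos_y: (H*W, 1) coordinate grids in [-1, 1].
        b, c, h, w = feat.shape
        attn = F.softmax(feat.reshape(b * c, h * w), dim=-1)  # per-channel heatmap
        kp_x = attn @ pos_x  # (B*C, 1) expected x coordinate
        kp_y = attn @ pos_y  # (B*C, 1) expected y coordinate
        return torch.cat([kp_x, kp_y], dim=-1).reshape(b, c, 2)
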
@@ -288,14 +308,17 @@ class VQBeTModel(nn.Module):
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.num_images = len(
[k for k in config.input_shapes if k.startswith("observation.image")]
)
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
config.input_shapes["observation.state"][0],
hidden_channels=[self.config.gpt_input_dim],
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -310,7 +333,12 @@ class VQBeTModel(nn.Module):
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
torch.row_stack(
[
torch.arange(i, i + self.config.action_chunk_size)
for i in range(num_tokens)
]
),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
@@ -325,7 +353,11 @@ class VQBeTModel(nn.Module):
)
# Separate batch and sequence dims.
img_features = einops.rearrange(
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
img_features,
"(b s n) ... -> b s n ...",
b=batch_size,
s=n_obs_steps,
n=self.num_images,
)
# Arrange prior and current observation step tokens as shown in the class docstring.
@@ -337,13 +369,19 @@ class VQBeTModel(nn.Module):
input_tokens.append(
self.state_projector(batch["observation.state"])
) # (batch, obs_step, projection dims)
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
input_tokens.append(
einops.repeat(
self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
)
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack(input_tokens, dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token - 1
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
future_action_tokens = self.action_token.repeat(
batch_size, len_additional_action_token, 1
)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
@@ -352,9 +390,9 @@ class VQBeTModel(nn.Module):
features = self.policy(input_tokens)
# len(self.config.input_shapes) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
self.config.input_shapes
)
historical_act_pred_index = np.arange(0, n_obs_steps) * (
len(self.config.input_shapes) + 1
) + len(self.config.input_shapes)
# only extract the output tokens at the position of action query:
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
@@ -362,7 +400,11 @@ class VQBeTModel(nn.Module):
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
if len_additional_action_token > 0:
features = torch.cat(
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
[
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:],
],
dim=1,
)
else:
features = features[:, historical_act_pred_index]
@@ -370,13 +412,15 @@ class VQBeTModel(nn.Module):
action_head_output = self.action_head(features)
        # if rollout, VQ-BeT doesn't calculate the loss
if rollout:
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
batch_size, self.config.action_chunk_size, -1
)
return action_head_output["predicted_action"][
:, n_obs_steps - 1, :
].reshape(batch_size, self.config.action_chunk_size, -1)
        # else, it calculates the overall loss (bin prediction loss and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
loss = self.action_head.loss_fn(
action_head_output, output, reduction="mean"
)
return action_head_output, loss
@@ -411,7 +455,9 @@ class VQBeTHead(nn.Module):
else:
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
hidden_channels=[
self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed
],
)
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
@@ -438,7 +484,10 @@ class VQBeTHead(nn.Module):
loss, metric = self.vqvae_model.vqvae_forward(actions)
n_different_codes = sum(
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
[
len(torch.unique(metric[2][:, i]))
for i in range(self.vqvae_model.vqvae_num_layers)
]
)
n_different_combinations = len(torch.unique(metric[2], dim=0))
recon_l1_error = metric[0].detach().cpu().item()
@@ -485,7 +534,13 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
torch.cat(
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
(
x,
F.one_hot(
sampled_primary_centers,
num_classes=self.config.vqvae_n_embed,
),
),
axis=1,
)
)
@@ -493,19 +548,29 @@ class VQBeTHead(nn.Module):
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
)
sampled_secondary_centers = einops.rearrange(
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
torch.multinomial(
cbet_secondary_probs.view(-1, choices), num_samples=1
),
"(NT) 1 -> NT",
NT=NT,
)
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
sampled_centers = torch.stack(
(sampled_primary_centers, sampled_secondary_centers), axis=1
)
cbet_logits = torch.stack(
[cbet_primary_logits, cbet_secondary_logits], dim=1
)
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
else:
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
cbet_logits,
"(NT) (G C) -> (NT) G C",
G=self.vqvae_model.vqvae_num_layers,
)
cbet_probs = torch.softmax(
cbet_logits / self.config.bet_softmax_temperature, dim=-1
)
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
NT, G, choices = cbet_probs.shape
sampled_centers = einops.rearrange(
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
@@ -525,9 +590,17 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
with torch.no_grad():
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
return_decoder_input = (
self.vqvae_model.get_embeddings_from_code(sampled_centers)
.clone()
.detach()
)
# pass the centroids through decoder to get actions.
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
)
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
@@ -576,7 +649,9 @@ class VQBeTHead(nn.Module):
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each ground truth action.
with torch.no_grad():
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
# Now we can compute the loss.
@@ -599,8 +674,12 @@ class VQBeTHead(nn.Module):
+ cbet_loss2 * self.config.secondary_code_loss_weight
)
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
) / (NT)
equal_secondary_code_rate = torch.sum(
(action_bins[:, 1] == sampled_centers[:, 1]).int()
) / (NT)
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
@@ -614,7 +693,9 @@ class VQBeTHead(nn.Module):
"classification_loss": cbet_loss.detach().cpu().item(),
"offset_loss": offset_loss.detach().cpu().item(),
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
"equal_secondary_code_rate": equal_secondary_code_rate.detach()
.cpu()
.item(),
"vq_action_error": vq_action_error.detach().cpu().item(),
"offset_action_error": offset_action_error.detach().cpu().item(),
"action_error_max": action_error_max.detach().cpu().item(),
@@ -643,11 +724,17 @@ class VQBeTOptimizer(torch.optim.Adam):
if cfg.policy.sequentially_select:
decay_params = (
decay_params
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
+ list(
policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()
)
+ list(
policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()
)
)
else:
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
decay_params = decay_params + list(
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()
)
optim_groups = [
{
@@ -693,7 +780,11 @@ class VQBeTScheduler(nn.Module):
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return max(
0.0,
0.5
* (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
@@ -717,7 +808,9 @@ class VQBeTRgbEncoder(nn.Module):
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
if config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
self.maybe_random_crop = torchvision.transforms.RandomCrop(
config.crop_shape
)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -738,7 +831,9 @@ class VQBeTRgbEncoder(nn.Module):
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
# Set up pooling and final layers.
@@ -746,17 +841,25 @@ class VQBeTRgbEncoder(nn.Module):
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
config.crop_shape
if config.crop_shape is not None
else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(
size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
@@ -783,7 +886,9 @@ class VQBeTRgbEncoder(nn.Module):
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
"""
Args:
@@ -796,7 +901,11 @@ def _replace_submodules(
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
replace_list = [
k.split(".")
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
@@ -811,7 +920,9 @@ def _replace_submodules(
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
assert not any(
predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)
)
return root_module
@@ -844,7 +955,8 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
in_channels=self.config.output_shapes["action"][0]
* self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -872,9 +984,13 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
)
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(
output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]
)
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)


@@ -123,9 +123,15 @@ class CausalSelfAttention(nn.Module):
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs)
k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
1, 2
) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
@@ -133,7 +139,9 @@ class CausalSelfAttention(nn.Module):
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
y = (
y.transpose(1, 2).contiguous().view(B, T, C)
) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
@@ -189,12 +197,16 @@ class GPT(nn.Module):
"ln_f": nn.LayerNorm(config.gpt_hidden_dim),
}
)
self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
self.lm_head = nn.Linear(
config.gpt_hidden_dim, config.gpt_output_dim, bias=False
)
# init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))
torch.nn.init.normal_(
p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
)
# report number of parameters
n_params = sum(p.numel() for p in self.parameters())
@@ -208,11 +220,17 @@ class GPT(nn.Module):
), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
# positional encodings that are added to the input embeddings
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
0
) # shape (1, t)
# forward the GPT model itself
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
tok_emb = self.transformer.wte(
input
) # token embeddings of shape (b, t, gpt_hidden_dim)
pos_emb = self.transformer.wpe(
pos
) # position embeddings of shape (1, t, gpt_hidden_dim)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
@@ -237,7 +255,9 @@ class GPT(nn.Module):
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
self.transformer.wpe.weight = nn.Parameter(
self.transformer.wpe.weight[:gpt_block_size]
)
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
@@ -270,7 +290,9 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
assert (
len(inter_params) == 0
), "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
)
assert (
@@ -368,8 +390,12 @@ class ResidualVQ(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.num_quantizers = num_quantizers
@@ -377,7 +403,10 @@ class ResidualVQ(nn.Module):
self.layers = nn.ModuleList(
[
VectorQuantize(
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
dim=codebook_dim,
codebook_dim=codebook_dim,
accept_image_fmap=accept_image_fmap,
**kwargs,
)
for _ in range(num_quantizers)
]
@@ -448,7 +477,9 @@ class ResidualVQ(nn.Module):
return all_codes
def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
def forward(
self, x, indices=None, return_all_codes=False, sample_codebook_temp=None
):
"""
For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
@@ -477,13 +508,17 @@ class ResidualVQ(nn.Module):
), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
ce_losses = []
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
should_quantize_dropout = (
self.training and self.quantize_dropout and not return_loss
)
# sample a layer index at which to dropout further residual quantization
# also prepare null indices and loss
if should_quantize_dropout:
rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)
rand_quantize_dropout_index = randrange(
self.quantize_dropout_cutoff_index, num_quant
)
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
@@ -492,14 +527,23 @@ class ResidualVQ(nn.Module):
- 1
)
null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
null_indices_shape = (
(x.shape[0], *x.shape[-2:])
if self.accept_image_fmap
else tuple(x.shape[:2])
)
null_indices = torch.full(
null_indices_shape, -1.0, device=device, dtype=torch.long
)
null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)
# go through the layers
for quantizer_index, layer in enumerate(self.layers):
if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
if (
should_quantize_dropout
and quantizer_index > rand_quantize_dropout_index
):
all_indices.append(null_indices)
all_losses.append(null_loss)
continue
@@ -539,7 +583,9 @@ class ResidualVQ(nn.Module):
# stack all losses and indices
all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))
all_losses, all_indices = map(
partial(torch.stack, dim=-1), (all_losses, all_indices)
)
ret = (quantized_out, all_indices, all_losses)
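
The loop in this hunk is standard residual vector quantization: each layer quantizes what the previous layers left unexplained. Distilled to its core (a sketch, assuming each layer returns a quantized vector plus its code index):

    import torch

    def residual_quantize(x, layers):
        residual, quantized_out, codes = x, torch.zeros_like(x), []
        for layer in layers:
            quantized, index = layer(residual)  # nearest codebook entry + its index
            residual = residual - quantized     # pass the remainder down the stack
            quantized_out = quantized_out + quantized
            codes.append(index)
        return quantized_out, codes
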
@@ -599,8 +645,12 @@ class VectorQuantize(nn.Module):
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
self.project_in = (
nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()
)
self.eps = eps
self.commitment_weight = commitment_weight
@@ -614,10 +664,14 @@ class VectorQuantize(nn.Module):
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"
assert not (
ema_update and learnable_codebook
), "learnable codebook not compatible with EMA update"
assert 0 <= sync_update_v <= 1.0
assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"
assert not (
sync_update_v > 0.0 and not learnable_codebook
), "learnable codebook must be turned on"
self.sync_update_v = sync_update_v
@@ -629,7 +683,9 @@ class VectorQuantize(nn.Module):
)
if sync_codebook is None:
sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
codebook_kwargs = {
"dim": codebook_dim,
@@ -794,11 +850,17 @@ class VectorQuantize(nn.Module):
# quantize again
quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
quantize, embed_ind, distances = self._codebook(
x, **codebook_forward_kwargs
)
if self.training:
# determine code to use for commitment loss
maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
maybe_detach = (
torch.detach
if not self.learnable_codebook or freeze_codebook
else identity
)
commit_quantize = maybe_detach(quantize)
@@ -808,7 +870,9 @@ class VectorQuantize(nn.Module):
if self.sync_update_v > 0.0:
# (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
quantize = quantize + self.sync_update_v * (quantize - quantize.detach())
quantize = quantize + self.sync_update_v * (
quantize - quantize.detach()
)
# function for calculating cross entropy loss to distance matrix
# used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss
@@ -841,7 +905,9 @@ class VectorQuantize(nn.Module):
embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)
if self.accept_image_fmap:
embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)
embed_ind = rearrange(
embed_ind, "b (h w) ... -> b h w ...", h=height, w=width
)
if only_one:
embed_ind = rearrange(embed_ind, "b 1 -> b")
@@ -895,8 +961,12 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
if (
self.orthogonal_reg_max_codes is not None
) and num_codes > self.orthogonal_reg_max_codes:
rand_ids = torch.randperm(num_codes, device=device)[
: self.orthogonal_reg_max_codes
]
codebook = codebook[:, rand_ids]
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
@@ -928,7 +998,9 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if mask is not None:
quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss
@@ -1038,7 +1110,9 @@ def sample_vectors(samples, num):
def batched_sample_vectors(samples, num):
return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)
return torch.stack(
[sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0
)
def pad_shape(shape, size, dim=0):
@@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num):
all_num_samples = all_gather_sizes(local_samples, dim=0)
if rank == 0:
samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
samples_per_rank = sample_multinomial(
num, all_num_samples / all_num_samples.sum()
)
else:
samples_per_rank = torch.empty_like(all_num_samples)
@@ -1202,7 +1278,9 @@ class EuclideanCodebook(nn.Module):
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = (
reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
reset_cluster_size
if (reset_cluster_size is not None)
else threshold_ema_dead_code
)
assert callable(gumbel_sample)
@@ -1213,8 +1291,14 @@ class EuclideanCodebook(nn.Module):
use_ddp and num_codebooks > 1 and kmeans_init
), "kmeans init is not compatible with multiple codebooks in distributed environment for now"
self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
self.sample_fn = (
sample_vectors_distributed
if use_ddp and sync_kmeans
else batched_sample_vectors
)
self.kmeans_all_reduce_fn = (
distributed.all_reduce if use_ddp and sync_kmeans else noop
)
self.all_reduce_fn = distributed.all_reduce if use_ddp else noop
self.register_buffer("initted", torch.Tensor([not kmeans_init]))
@@ -1353,7 +1437,9 @@ class EuclideanCodebook(nn.Module):
distributed.all_reduce(variance_numer)
batch_variance = variance_numer / num_vectors
self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)
self.update_with_decay(
"batch_variance", batch_variance, self.affine_param_batch_decay
)
def replace(self, batch_samples, batch_mask):
for ind, (samples, mask) in enumerate(
@@ -1362,7 +1448,9 @@ class EuclideanCodebook(nn.Module):
if not torch.any(mask):
continue
sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
sampled = self.sample_fn(
rearrange(samples, "... -> 1 ..."), mask.sum().item()
)
sampled = rearrange(sampled, "1 ... -> ...")
self.embed.data[ind][mask] = sampled
@@ -1386,7 +1474,9 @@ class EuclideanCodebook(nn.Module):
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = (
sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
sample_codebook_temp
if (sample_codebook_temp is not None)
else self.sample_codebook_temp
)
x = x.float()
@@ -1414,7 +1504,9 @@ class EuclideanCodebook(nn.Module):
if self.affine_param:
codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
embed = (embed - self.codebook_mean) * (
batch_std / codebook_std
) + self.batch_mean
dist = -cdist(flatten, embed)
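The affine branch above re-expresses the codebook under the batch's statistics before computing distances; a self-contained sketch of that rescaling (all tensors fabricated here):
import torch

embed = torch.randn(512, 64)                              # fabricated codebook
codebook_mean = embed.mean(dim=0)
codebook_std = embed.std(dim=0).clamp(min=1e-5)
batch_mean, batch_std = torch.zeros(64), torch.ones(64)   # fabricated batch statistics
# standardize under codebook stats, then re-scale into batch stats
embed_in_batch_space = (embed - codebook_mean) * (batch_std / codebook_std) + batch_mean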
@@ -1432,7 +1524,9 @@ class EuclideanCodebook(nn.Module):
if self.training and self.ema_update and not freeze_codebook:
if self.affine_param:
flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
flatten = (flatten - self.batch_mean) * (
codebook_std / batch_std
) + self.codebook_mean
if mask is not None:
embed_onehot[~mask] = 0.0
@@ -1455,7 +1549,9 @@ class EuclideanCodebook(nn.Module):
self.expire_codes_(x)
if needs_codebook_dim:
quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))
quantize, embed_ind = tuple(
rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)
)
dist = unpack_one(dist, ps, "h * d")

View File

@@ -65,7 +65,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
img.save(str(path), quality=100)
logging.info(f"Saved image: {path}")
except Exception as e:
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
logging.error(
f"Failed to save image for camera {serial_number} frame {frame_index}: {e}"
)
def save_images_from_cameras(
@@ -94,7 +96,9 @@ def save_images_from_cameras(
cameras = []
for cam_sn in serial_numbers:
print(f"{cam_sn=}")
camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock)
camera = IntelRealSenseCamera(
cam_sn, fps=fps, width=width, height=height, mock=mock
)
camera.connect()
print(
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
@@ -140,7 +144,9 @@ def save_images_from_cameras(
if time.perf_counter() - start_time > record_time_s:
break
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
frame_index += 1
finally:
@@ -182,8 +188,12 @@ class IntelRealSenseCameraConfig:
self.channels = 3
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
at_least_one_is_not_none = (
self.fps is not None or self.width is not None or self.height is not None
)
at_least_one_is_none = (
self.fps is None or self.width is None or self.height is None
)
if at_least_one_is_not_none and at_least_one_is_none:
raise ValueError(
"For `fps`, `width` and `height`, either all of them need to be set, or none of them, "
@@ -191,7 +201,9 @@ class IntelRealSenseCameraConfig:
)
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
class IntelRealSenseCamera:
@@ -286,7 +298,9 @@ class IntelRealSenseCamera:
self.rotation = cv2.ROTATE_180
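For reference, the rotation values the config validates plausibly map onto OpenCV rotate codes as below; this mapping is a sketch, since only the ROTATE_180 branch is visible in this hunk:
import cv2

# hypothetical mapping from the validated `rotation` degrees to cv2 codes
ROTATION_TO_CV2 = {
    None: None,
    90: cv2.ROTATE_90_CLOCKWISE,
    -90: cv2.ROTATE_90_COUNTERCLOCKWISE,
    180: cv2.ROTATE_180,
}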
@classmethod
def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs):
def init_from_name(
cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs
):
camera_infos = find_cameras()
camera_names = [cam["name"] for cam in camera_infos]
this_name_count = Counter(camera_names)[name]
@@ -296,7 +310,9 @@ class IntelRealSenseCamera:
f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them."
)
name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos}
name_to_serial_dict = {
cam["name"]: cam["serial_number"] for cam in camera_infos
}
cam_sn = name_to_serial_dict[name]
if config is None:
@@ -323,13 +339,17 @@ class IntelRealSenseCamera:
if self.fps and self.width and self.height:
# TODO(rcadene): can we set rgb8 directly?
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
config.enable_stream(
rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
)
else:
config.enable_stream(rs.stream.color)
if self.use_depth:
if self.fps and self.width and self.height:
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
config.enable_stream(
rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
)
else:
config.enable_stream(rs.stream.depth)
@@ -362,7 +382,9 @@ class IntelRealSenseCamera:
actual_height = color_profile.height()
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad error that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
@@ -382,7 +404,9 @@ class IntelRealSenseCamera:
self.is_connected = True
def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
def read(
self, temporary_color: str | None = None
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3)
of type `np.uint8`, contrary to the pytorch format, which is float channel-first.
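To bridge the two layouts the docstring mentions, a frame from `read()` can be converted to the pytorch format like this (a minimal sketch):
import numpy as np
import torch

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # height x width x channels, as read() returns
# channel-first float32 in [0, 1], the layout policies typically expect
tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0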
@@ -409,11 +433,15 @@ class IntelRealSenseCamera:
color_frame = frame.get_color_frame()
if not color_frame:
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture color image from IntelRealSenseCamera({self.serial_number})."
)
color_image = np.asanyarray(color_frame.get_data())
requested_color_mode = self.color_mode if temporary_color is None else temporary_color
requested_color_mode = (
self.color_mode if temporary_color is None else temporary_color
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
@@ -441,7 +469,9 @@ class IntelRealSenseCamera:
if self.use_depth:
depth_frame = frame.get_depth_frame()
if not depth_frame:
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
raise OSError(
f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})."
)
depth_map = np.asanyarray(depth_frame.get_data())
@@ -483,7 +513,9 @@ class IntelRealSenseCamera:
# TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here
num_tries += 1
time.sleep(1 / self.fps)
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
if num_tries > self.fps and (
self.thread.ident is None or not self.thread.is_alive()
):
raise Exception(
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
)

View File

@@ -31,10 +31,14 @@ from lerobot.common.utils.utils import capture_timestamp_utc
MAX_OPENCV_INDEX = 60
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
def find_cameras(
raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False
) -> list[dict]:
cameras = []
if platform.system() == "Linux":
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
print(
"Linux detected. Finding available camera indices through scanning '/dev/video*' ports"
)
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
ports = _find_cameras(possible_ports, mock=mock)
for port in ports:
@@ -165,7 +169,9 @@ def save_images_from_cameras(
dt_s = time.perf_counter() - now
busy_wait(1 / fps - dt_s)
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
print(
f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}"
)
if time.perf_counter() - start_time > record_time_s:
break
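`busy_wait` is imported from lerobot's utils and its body is not shown in this diff; a plausible sketch of the spin-wait pacing it implements, paired with the loop above:
import time

def busy_wait(seconds: float) -> None:
    # spin until the deadline; busy waiting keeps frame pacing tighter than time.sleep on some platforms
    deadline = time.perf_counter() + seconds
    while time.perf_counter() < deadline:
        pass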
@@ -205,7 +211,9 @@ class OpenCVCameraConfig:
self.channels = 3
if self.rotation not in [-90, None, 90, 180]:
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
raise ValueError(
f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})"
)
class OpenCVCamera:
@@ -247,7 +255,12 @@ class OpenCVCamera:
```
"""
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
def __init__(
self,
camera_index: int | str,
config: OpenCVCameraConfig | None = None,
**kwargs,
):
if config is None:
config = OpenCVCameraConfig()
@@ -261,12 +274,16 @@ class OpenCVCamera:
if platform.system() == "Linux":
if isinstance(self.camera_index, int):
self.port = Path(f"/dev/video{self.camera_index}")
elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index):
elif isinstance(self.camera_index, str) and is_valid_unix_path(
self.camera_index
):
self.port = Path(self.camera_index)
# Retrieve the camera index from a potentially symlinked path
self.camera_index = get_camera_index_from_unix_port(self.port)
else:
raise ValueError(f"Please check the provided camera_index: {camera_index}")
raise ValueError(
f"Please check the provided camera_index: {camera_index}"
)
self.fps = config.fps
self.width = config.width
@@ -298,7 +315,9 @@ class OpenCVCamera:
def connect(self):
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
raise RobotDeviceAlreadyConnectedError(
f"OpenCVCamera({self.camera_index}) is already connected."
)
if self.mock:
import tests.mock_cv2 as cv2
@@ -309,7 +328,11 @@ class OpenCVCamera:
# when other threads are used to save the images.
cv2.setNumThreads(1)
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
camera_idx = (
f"/dev/video{self.camera_index}"
if platform.system() == "Linux"
else self.camera_index
)
# First create a temporary camera trying to access `camera_index`,
# and verify it is a valid camera by calling `isOpened`.
tmp_camera = cv2.VideoCapture(camera_idx)
@@ -349,16 +372,22 @@ class OpenCVCamera:
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
# Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30)
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
if self.fps is not None and not math.isclose(
self.fps, actual_fps, rel_tol=1e-3
):
# Using `OSError` since it's a broad error that encompasses issues related to device communication
raise OSError(
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
)
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
if self.width is not None and not math.isclose(
self.width, actual_width, rel_tol=1e-3
):
raise OSError(
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
)
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
if self.height is not None and not math.isclose(
self.height, actual_height, rel_tol=1e-3
):
raise OSError(
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
)
@@ -388,7 +417,9 @@ class OpenCVCamera:
if not ret:
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
requested_color_mode = (
self.color_mode if temporary_color_mode is None else temporary_color_mode
)
if requested_color_mode not in ["rgb", "bgr"]:
raise ValueError(

View File

@@ -11,6 +11,7 @@ from copy import copy
from functools import cache
import cv2
import numpy as np
import torch
import tqdm
from deepdiff import DeepDiff
@@ -22,11 +23,17 @@ from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_hydra_config,
set_global_seed,
)
from lerobot.scripts.eval import get_pretrained_policy_path
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
def log_control_info(
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
):
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -35,7 +42,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
def log_dt(shortname, dt_val_s):
nonlocal log_items, fps
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
if fps is not None:
actual_fps = 1 / dt_val_s
if actual_fps < fps - 1:
@@ -97,7 +104,9 @@ def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
torch.autocast(device_type=device.type)
if device.type == "cuda" and use_amp
else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
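The hunk above composes inference mode with an optional autocast; the same pattern in isolation (a sketch, with `use_amp` assumed):
from contextlib import nullcontext
import torch

use_amp = True  # assumed flag
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp_ctx = torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext()
with torch.inference_mode(), amp_ctx:
    pass  # forward pass runs here without autograd, in reduced precision when autocast is active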
@@ -120,14 +129,22 @@ def predict_action(observation, policy, device, use_amp):
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
def init_keyboard_listener(assign_rewards=False):
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
Args:
assign_rewards (bool): If True, allows annotating the collected trajectory
with a binary reward at the end of the episode to indicate success.
"""
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if assign_rewards:
events["next.reward"] = 0
if is_headless():
logging.warning(
@@ -145,13 +162,22 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif assign_rewards and key == keyboard.Key.space:
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
print(
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
events["next.reward"],
)
except Exception as e:
print(f"Error handling key press: {e}")
@@ -164,8 +190,12 @@ def init_keyboard_listener():
def init_policy(pretrained_policy_name_or_path, policy_overrides):
"""Instantiate the policy and load fps, device and use_amp from config yaml"""
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
hydra_cfg = init_hydra_config(
pretrained_policy_path / "config.yaml", policy_overrides
)
policy = make_policy(
hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path
)
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -209,6 +239,7 @@ def record_episode(
device,
use_amp,
fps,
record_delta_actions,
):
control_loop(
robot=robot,
@@ -220,6 +251,7 @@ def record_episode(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
teleoperate=policy is None,
)
@@ -236,6 +268,7 @@ def control_loop(
device=None,
use_amp=None,
fps=None,
record_delta_actions=False,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@@ -251,15 +284,21 @@ def control_loop(
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
raise ValueError(
f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps})."
)
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
if teleoperate:
observation, action = robot.teleop_step(record_data=True)
if record_delta_actions:
action["action"] = action["action"] - current_joint_positions
else:
observation = robot.capture_observation()
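With `record_delta_actions` enabled above, the dataset stores displacements rather than absolute positions, so a replay would need the inverse; a sketch with fabricated arrays:
import numpy as np

current_joint_positions = np.array([0.0, 45.0, -30.0])
absolute_action = np.array([5.0, 40.0, -20.0])
delta = absolute_action - current_joint_positions   # what gets recorded
replayed = current_joint_positions + delta          # what a replay would send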
@@ -272,12 +311,22 @@ def control_loop(
if dataset is not None:
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
frame["next.done"] = (events["next.reward"] == 1) or (
events["exit_early"]
)
dataset.add_frame(frame)
# if frame["next.done"]:
# break
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if fps is not None:
@@ -301,6 +350,8 @@ def reset_environment(robot, events, reset_time_s):
timestamp = 0
start_vencod_t = time.perf_counter()
if "next.reward" in events:
events["next.reward"] = 0
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
@@ -313,6 +364,16 @@ def reset_environment(robot, events, reset_time_s):
break
def reset_follower_position(robot: Robot, target_position):
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 50 is just an arbitrary number
for pose in trajectory:
robot.send_action(pose)
busy_wait(0.015)
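`np.linspace` broadcasts over the joint vectors, so the trajectory above has one row per interpolation step; a quick check with fabricated positions:
import numpy as np

current = np.zeros(6)
target = np.full(6, 90.0)
trajectory = np.linspace(current, target, 50)
assert trajectory.shape == (50, 6)  # 50 intermediate poses, one value per joint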
def stop_recording(robot, listener, display_cameras):
robot.disconnect()
@@ -343,21 +404,32 @@ def sanity_check_dataset_name(repo_id, policy):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
dataset: LeRobotDataset,
robot: Robot,
fps: int,
use_videos: bool,
extra_features: dict = None,
) -> None:
features_from_robot = get_features_from_robot(robot, use_videos)
if extra_features is not None:
features_from_robot.update(extra_features)
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
("features", dataset.features, features_from_robot),
]
mismatches = []
for field, dataset_value, present_value in fields:
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
diff = DeepDiff(
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
)
if diff:
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
if mismatches:
raise ValueError(
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
"Dataset metadata compatibility check failed with mismatches:\n"
+ "\n".join(mismatches)
)
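DeepDiff returns an empty (falsy) result when the structures match, which is what the loop above relies on; a tiny illustration:
from deepdiff import DeepDiff

diff = DeepDiff({"fps": 30}, {"fps": 25}, exclude_regex_paths=[r".*\['info'\]$"])
if diff:
    print("mismatch:", diff)  # non-empty mapping describing the changed values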

View File

@@ -8,7 +8,10 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 2.0
@@ -143,7 +146,9 @@ NUM_READ_RETRY = 10
NUM_WRITE_RETRY = 10
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
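A worked instance of the conversion the docstring describes, assuming a 4096-step motor:
resolution = 4096
degrees = 90.0
# half a turn (180 degrees) spans half the resolution, so:
steps = degrees / 180 * (resolution // 2)
assert steps == 1024.0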
@@ -378,7 +383,9 @@ class DynamixelMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@@ -394,7 +401,9 @@ class DynamixelMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -415,7 +424,9 @@ class DynamixelMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -428,7 +439,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -503,7 +516,9 @@ class DynamixelMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -545,15 +560,23 @@ class DynamixelMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-(resolution // 2) - values[i] - homing_offset
) / resolution
upp_factor = (
(resolution // 2) - values[i] - homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
start_pos = self.calibration["start_pos"][calib_idx]
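Plugging numbers into the DEGREE-mode inequality from the hunk above makes the factor bounds concrete; a sketch assuming a 4096-step resolution and a 180-degree half turn:
import math

resolution, half_turn = 4096, 180
value, homing_offset = 5000, 0  # fabricated out-of-range reading
# -180 <= (value + homing_offset + resolution * factor) / (resolution // 2) * half_turn <= 180
low_factor = (-(resolution // 2) - value - homing_offset) / resolution
upp_factor = ((resolution // 2) - value - homing_offset) / resolution
factor = math.ceil(low_factor)  # smallest integer shift bringing the value back into range
assert low_factor <= factor <= upp_factor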
@@ -561,7 +584,9 @@ class DynamixelMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -577,19 +602,27 @@ class DynamixelMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -599,7 +632,9 @@ class DynamixelMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -638,7 +673,9 @@ class DynamixelMotorsBus:
values = np.round(values).astype(np.int32)
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@@ -740,7 +777,9 @@ class DynamixelMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -749,7 +788,9 @@ class DynamixelMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.mock_dynamixel_sdk as dxl
else:
@@ -778,7 +819,12 @@ class DynamixelMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@@ -839,7 +885,9 @@ class DynamixelMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -8,7 +8,10 @@ from copy import deepcopy
import numpy as np
import tqdm
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from lerobot.common.utils.utils import capture_timestamp_utc
PROTOCOL_VERSION = 0
@@ -122,7 +125,9 @@ NUM_READ_RETRY = 20
NUM_WRITE_RETRY = 20
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
def convert_degrees_to_steps(
degrees: float | np.ndarray, models: str | list[str]
) -> np.ndarray:
"""This function converts the degree range to the step range for indicating motors rotation.
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -358,7 +363,9 @@ class FeetechMotorsBus:
indices = []
for idx in tqdm.tqdm(possible_ids):
try:
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
present_idx = self.read_with_motor_ids(
self.motor_models, [idx], "ID", num_retry=num_retry
)[0]
except ConnectionError:
continue
@@ -374,7 +381,9 @@ class FeetechMotorsBus:
def set_bus_baudrate(self, baudrate):
present_bus_baudrate = self.port_handler.getBaudRate()
if present_bus_baudrate != baudrate:
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
print(
f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
)
self.port_handler.setBaudRate(baudrate)
if self.port_handler.getBaudRate() != baudrate:
@@ -395,7 +404,9 @@ class FeetechMotorsBus:
def set_calibration(self, calibration: dict[str, list]):
self.calibration = calibration
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration_autocorrect(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -408,7 +419,9 @@ class FeetechMotorsBus:
values = self.apply_calibration(values, motor_names)
return values
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def apply_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
a "zero position" at 0 degree.
@@ -482,7 +495,9 @@ class FeetechMotorsBus:
return values
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def autocorrect_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
Some motors might have values outside of expected maximum bounds after calibration.
@@ -521,18 +536,26 @@ class FeetechMotorsBus:
values[i] *= -1
# Convert from initial range to range [-180, 180] degrees
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
calib_val = (
(values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
)
in_range = (calib_val > LOWER_BOUND_DEGREE) and (
calib_val < UPPER_BOUND_DEGREE
)
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
low_factor = (
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
upp_factor = (
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2)
- values[i]
- homing_offset
) / resolution
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
@@ -541,7 +564,9 @@ class FeetechMotorsBus:
# Convert from initial range to range [0, 100] in %
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
in_range = (calib_val > LOWER_BOUND_LINEAR) and (
calib_val < UPPER_BOUND_LINEAR
)
# Solve this inequality to find the factor to shift the range into [0, 100] %
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
@@ -557,19 +582,27 @@ class FeetechMotorsBus:
factor = math.ceil(low_factor)
if factor > upp_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
else:
factor = math.ceil(upp_factor)
if factor > low_factor:
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
raise ValueError(
f"No integer found between bounds [{low_factor=}, {upp_factor=}]"
)
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
out_of_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
in_range_str = (
f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
)
logging.warning(
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
@@ -579,7 +612,9 @@ class FeetechMotorsBus:
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
self.calibration["homing_offset"][calib_idx] += resolution * factor
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
def revert_calibration(
self, values: np.ndarray | list, motor_names: list[str] | None
):
"""Inverse of `apply_calibration`."""
if motor_names is None:
motor_names = self.motor_names
@@ -655,7 +690,9 @@ class FeetechMotorsBus:
return values
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
def read_with_motor_ids(
self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY
):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@@ -760,7 +797,9 @@ class FeetechMotorsBus:
values = self.apply_calibration_autocorrect(values, motor_names)
# log the number of seconds it took to read the data from the motors
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "read", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# log the utc time at which the data was received
@@ -769,7 +808,9 @@ class FeetechMotorsBus:
return values
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
def write_with_motor_ids(
self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY
):
if self.mock:
import tests.mock_scservo_sdk as scs
else:
@@ -798,7 +839,12 @@ class FeetechMotorsBus:
f"{self.packet_handler.getTxRxResult(comm)}"
)
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
def write(
self,
data_name,
values: int | float | np.ndarray,
motor_names: str | list[str] | None = None,
):
if not self.is_connected:
raise RobotDeviceNotConnectedError(
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
@@ -859,7 +905,9 @@ class FeetechMotorsBus:
)
# log the number of seconds it took to write the data to the motors
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
delta_ts_name = get_log_name(
"delta_timestamp_s", "write", data_name, motor_names
)
self.logs[delta_ts_name] = time.perf_counter() - start_time
# TODO(rcadene): should we log the time before sending the write command?

View File

@@ -10,9 +10,7 @@ from lerobot.common.robot_devices.motors.dynamixel import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -23,7 +21,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0, 1], with 0 meaning the original rotation direction for the motor and 1 meaning inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
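The body of `apply_drive_mode` is not shown in this diff; a plausible sketch of its behavior, given that drive mode 1 inverts the rotation direction:
import numpy as np

def apply_drive_mode(position, drive_mode):
    # hypothetical sketch: flip the sign wherever drive_mode is 1, keep it where it is 0
    return position * np.where(np.asarray(drive_mode) == 1, -1, 1)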
@@ -64,12 +64,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -90,10 +94,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinematic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -102,11 +111,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
# Re-compute homing offset to take into account drive mode
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
rotated_nearest_pos = compute_nearest_rounded_position(
rotated_drived_pos, arm.motor_models
)
homing_offset = rotated_target_pos - rotated_nearest_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@@ -12,9 +12,7 @@ from lerobot.common.robot_devices.motors.feetech import (
)
from lerobot.common.robot_devices.motors.utils import MotorsBus
URL_TEMPLATE = (
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
)
URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
# The following positions are provided in nominal degree range ]-180, +180[
# For more info on these constants, see comments in the code where they get used.
@@ -25,7 +23,9 @@ ROTATED_POSITION_DEGREE = 90
def assert_drive_mode(drive_mode):
# `drive_mode` is in [0, 1], with 0 meaning the original rotation direction for the motor and 1 meaning inverted.
if not np.all(np.isin(drive_mode, [0, 1])):
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
raise ValueError(
f"`drive_mode` contains values other than 0 or 1: ({drive_mode})"
)
def apply_drive_mode(position, drive_mode):
@@ -126,7 +126,9 @@ def apply_offset(calib, offset):
return calib
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
if robot_type == "so100":
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
elif robot_type == "moss":
@@ -135,18 +137,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm
raise ValueError(robot_type)
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_so100(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "so100" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of so100 arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -193,11 +204,16 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
print("Calibrate elbow_flex")
calib["elbow_flex"] = move_to_calibrate(
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
arm,
"elbow_flex",
positive_first=False,
in_between_move_hook=in_between_move_hook,
)
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex"
)
time.sleep(1)
def in_between_move_hook():
@@ -225,18 +241,30 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
}
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift")
arm.write(
"Goal_Position",
round(calib["shoulder_lift"]["zero_pos"] - 1600),
"shoulder_lift",
)
time.sleep(2)
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
arm.write(
"Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
arm.write(
"Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex"
)
time.sleep(2)
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
time.sleep(2)
print("Calibrate wrist_roll")
calib["wrist_roll"] = move_to_calibrate(
arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook
arm,
"wrist_roll",
invert_drive_mode=True,
positive_first=False,
while_move_hook=while_move_hook,
)
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
@@ -246,7 +274,9 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift"
)
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
time.sleep(1)
@@ -275,18 +305,27 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
return calib_dict
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_auto_calibration_moss(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
if not (robot_type == "moss" and arm_type == "follower"):
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
raise NotImplementedError(
"Auto calibration only supports the follower of moss arms for now."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to initial position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")
)
input("Press Enter to continue...")
# Lower the acceleration of the motors (in [0,254])
@@ -370,8 +409,12 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
time.sleep(1)
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
arm.write(
"Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift"
)
arm.write(
"Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex"
)
time.sleep(2)
calib_modes = []
@@ -398,7 +441,9 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
return calib_dict
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
def run_arm_manual_calibration(
arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str
):
"""This function ensures that a neural network trained on data collected on a given robot
can work on another robot. For instance, before calibration, sending the same goal position
to each motor of two different robots will result in two very different positions. But after calibration,
@@ -421,12 +466,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
```
"""
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run calibration, the torque must be disabled on all motors.")
raise ValueError(
"To run calibration, the torque must be disabled on all motors."
)
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
print("\nMove arm to zero position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")
)
input("Press Enter to continue...")
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
@@ -446,10 +495,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
# of the previous motor in the kinematic chain.
print("\nMove arm to rotated target position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
print(
"See: "
+ URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")
)
input("Press Enter to continue...")
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
rotated_target_pos = convert_degrees_to_steps(
ROTATED_POSITION_DEGREE, arm.motor_models
)
# Find drive mode by rotating each motor by a quarter of a turn.
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
@@ -461,7 +515,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
homing_offset = rotated_target_pos - rotated_drived_pos
print("\nMove arm to rest position")
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
print(
"See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
)
input("Press Enter to continue...")
print()

View File

@@ -18,11 +18,16 @@ import torch
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.robots.utils import get_arm_id
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
def ensure_safe_goal_position(
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
goal_pos: torch.Tensor,
present_pos: torch.Tensor,
max_relative_target: float | list[float],
):
# Cap relative action target magnitude for safety.
diff = goal_pos - present_pos
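A condensed sketch of the full safety cap this function applies (the clamp itself falls outside the visible hunks; `max_relative_target` assumed to be a float here):
import torch

def capped_goal(goal_pos, present_pos, max_relative_target):
    diff = goal_pos - present_pos
    safe_diff = torch.clamp(diff, -max_relative_target, max_relative_target)  # cap per-joint displacement
    return present_pos + safe_diff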
@@ -32,7 +37,7 @@ def ensure_safe_goal_position(
safe_goal_pos = present_pos + safe_diff
if not torch.allclose(goal_pos, safe_goal_pos):
logging.warning(
logging.debug(
"Relative goal position magnitude had to be clamped to be safe.\n"
f" requested relative goal position target: {diff}\n"
f" clamped relative goal position target: {safe_diff}"
@@ -67,8 +72,14 @@ class ManipulatorRobotConfig:
# gripper is not put in torque mode.
gripper_open_degree: float | None = None
joint_position_relative_bounds: dict[np.ndarray] | None = None
def __setattr__(self, prop: str, val):
if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
if (
prop == "max_relative_target"
and val is not None
and isinstance(val, Sequence)
):
for name in self.follower_arms:
if len(self.follower_arms[name].motors) != len(val):
raise ValueError(
@@ -78,11 +89,16 @@ class ManipulatorRobotConfig:
"Note: This feature does not yet work with robots where different follower arms have "
"different numbers of motors."
)
if prop == "joint_position_relative_bounds" and val is not None:
for key in val:
val[key] = torch.tensor(val[key])
super().__setattr__(prop, val)
def __post_init__(self):
if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
raise ValueError(
f"Provided robot type ({self.robot_type}) is not supported."
)
class ManipulatorRobot:
@@ -336,7 +352,9 @@ class ManipulatorRobot:
# to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms:
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
# Check both arms can be read
for name in self.follower_arms:
@@ -368,18 +386,26 @@ class ManipulatorRobot:
print(f"Missing calibration file '{arm_calib_path}'")
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
run_arm_calibration,
)
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_calibration(
arm, self.robot_type, name, arm_type
)
elif self.robot_type in ["so100", "moss"]:
from lerobot.common.robot_devices.robots.feetech_calibration import (
run_arm_manual_calibration,
)
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
calibration = run_arm_manual_calibration(
arm, self.robot_type, name, arm_type
)
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
print(
f"Calibration is done! Saving calibration file '{arm_calib_path}'"
)
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
with open(arm_calib_path, "w") as f:
json.dump(calibration, f)
@@ -398,13 +424,17 @@ class ManipulatorRobot:
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
raise ValueError(
"To run set robot preset, the torque must be disabled on all motors."
)
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't
# rotate more than 360 degrees (from 0 to 4095). Some mistakes can happen while assembling the arm:
# you could end up with a servo with a position 0 or 4095 at a crucial point. See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
all_motors_except_gripper = [
name for name in arm.motor_names if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Koch motors
arm.write("Operating_Mode", 4, all_motors_except_gripper)
@@ -433,7 +463,9 @@ class ManipulatorRobot:
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
# so that we can use it as a trigger to close the gripper of the follower arms.
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper")
self.leader_arms[name].write(
"Goal_Position", self.config.gripper_open_degree, "gripper"
)
def set_aloha_robot_preset(self):
def set_shadow_(arm):
@@ -463,11 +495,15 @@ class ManipulatorRobot:
# you could end up with a servo with a position 0 or 4095 at a crucial point. See [
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
all_motors_except_gripper = [
name for name in self.follower_arms[name].motor_names if name != "gripper"
name
for name in self.follower_arms[name].motor_names
if name != "gripper"
]
if len(all_motors_except_gripper) > 0:
# 4 corresponds to Extended Position on Aloha motors
self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper)
self.follower_arms[name].write(
"Operating_Mode", 4, all_motors_except_gripper
)
# Use 'position control current based' for the follower gripper so that it is limited by the current limit.
# It can grasp an object without forcing too much, even though
@@ -515,7 +551,9 @@ class ManipulatorRobot:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arms[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
self.logs[f"read_leader_{name}_pos_dt_s"] = (
time.perf_counter() - before_lread_t
)
# Send goal position to the follower
follower_goal_pos = {}
@@ -523,19 +561,31 @@ class ManipulatorRobot:
before_fwrite_t = time.perf_counter()
goal_pos = leader_pos[name]
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Used when record_data=True
follower_goal_pos[name] = goal_pos
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = (
time.perf_counter() - before_fwrite_t
)
# Early exit when recording data is not requested
if not record_data:
@@ -548,7 +598,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@@ -570,8 +622,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -595,7 +651,9 @@ class ManipulatorRobot:
before_fread_t = time.perf_counter()
follower_pos[name] = self.follower_arms[name].read("Present_Position")
follower_pos[name] = torch.from_numpy(follower_pos[name])
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
self.logs[f"read_follower_{name}_pos_dt_s"] = (
time.perf_counter() - before_fread_t
)
# Create state by concatenating follower current position
state = []
@@ -610,8 +668,12 @@ class ManipulatorRobot:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries and format to pytorch
obs_dict = {}
@@ -644,18 +706,29 @@ class ManipulatorRobot:
goal_pos = action[from_idx:to_idx]
from_idx = to_idx
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
if self.config.joint_position_relative_bounds is not None:
goal_pos = torch.clamp(
goal_pos,
self.config.joint_position_relative_bounds["min"],
self.config.joint_position_relative_bounds["max"],
)
# Cap goal position when too far away from present position.
# Slower fps expected due to reading from the follower.
if self.config.max_relative_target is not None:
present_pos = self.follower_arms[name].read("Present_Position")
present_pos = torch.from_numpy(present_pos)
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
goal_pos = ensure_safe_goal_position(
goal_pos, present_pos, self.config.max_relative_target
)
# Save tensor to concat and return
action_sent.append(goal_pos)
# Send goal position to each follower
goal_pos = goal_pos.numpy().astype(np.int32)
self.follower_arms[name].write("Goal_Position", goal_pos)
return torch.cat(action_sent)
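For intuition, a standalone sketch of the two-stage safety capping applied above in both teleop_step and send_action; the bounds and positions here are illustrative, not real robot limits:

import torch

goal = torch.tensor([120.0, -40.0, 95.0])
bounds_min = torch.tensor([-90.0, -90.0, -90.0])  # illustrative config bounds
bounds_max = torch.tensor([90.0, 90.0, 90.0])

# Stage 1: clip to the absolute joint_position_relative_bounds from the config.
goal = torch.clamp(goal, bounds_min, bounds_max)  # -> [90., -40., 90.]

# Stage 2: cap the step relative to the present position, mirroring what
# ensure_safe_goal_position does with max_relative_target.
present = torch.tensor([80.0, -35.0, 50.0])
max_relative_target = 10.0
goal = present + torch.clamp(goal - present, -max_relative_target, max_relative_target)
# -> [90., -40., 60.]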

View File

@@ -60,7 +60,9 @@ class StretchRobot(StretchAPI):
def connect(self) -> None:
self.is_connected = self.startup()
if not self.is_connected:
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
print(
"Another process is already using Stretch. Try running 'stretch_free_robot_process.py'"
)
raise ConnectionError()
for name in self.cameras:
@@ -68,7 +70,9 @@ class StretchRobot(StretchAPI):
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged in.")
print(
"Could not connect to the cameras, check that all cameras are plugged in."
)
raise ConnectionError()
self.run_calibration()
@@ -113,8 +117,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionaries
obs_dict, action_dict = {}, {}
@@ -158,8 +166,12 @@ class StretchRobot(StretchAPI):
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[
"delta_timestamp_s"
]
self.logs[f"async_read_camera_{name}_dt_s"] = (
time.perf_counter() - before_camread_t
)
# Populate output dictionnaries
obs_dict = {}

View File

@@ -34,7 +34,8 @@ class RobotDeviceNotConnectedError(Exception):
"""Exception raised when the robot device is not connected."""
def __init__(
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
self,
message="This robot device is not connected. Try calling `robot_device.connect()` first.",
):
self.message = message
super().__init__(self.message)

View File

@@ -17,7 +17,9 @@ import importlib
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
def is_package_available(
pkg_name: str, return_version: bool = False
) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
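A usage sketch for the rewrapped helper:

has_torch, torch_version = is_package_available("torch", return_version=True)
has_gym = is_package_available("gymnasium")  # returns a bare bool by default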

View File

@@ -22,6 +22,8 @@ def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
"ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning,
)
imageio.mimsave(video_path, stacked_frames, fps=fps)
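A usage sketch, assuming frames are HxWxC uint8 arrays:

import numpy as np

frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(30)]
write_video("rollout.mp4", frames, fps=30)  # the pkg_resources DeprecationWarning is filtered internally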

View File

@@ -18,6 +18,7 @@ import os
import os.path as osp
import platform
import random
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
@@ -115,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]:
set_global_random_state(random_state_dict)
def init_logging():
def init_logging(log_file=None):
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@@ -133,6 +134,12 @@ def init_logging():
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
if log_file is not None:
# File handler
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logging.getLogger().addHandler(file_handler)
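A usage sketch for the extended signature; the log path is illustrative:

init_logging(log_file="outputs/learner.log")  # hypothetical path
logging.info("This line reaches both the console and the file, tagged with the PID.")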
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -155,11 +162,16 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
"/".join(
[".."] * (len(path2.parts) - len(common_parts))
+ list(path1.parts[len(common_parts) :])
)
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
def init_hydra_config(
config_path: str, overrides: list[str] | None = None
) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
@@ -168,7 +180,11 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
str(
_relative_path_between(
Path(config_path).absolute().parent, Path(__file__).absolute().parent
)
),
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)
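A usage sketch for the rewrapped helper; the config path and override are illustrative:

cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["seed=1"])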
@@ -182,10 +198,26 @@ def print_cuda_memory_usage():
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
print(
"Current GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.memory_allocated(0) / 1024**2
)
)
print(
"Maximum GPU Memory Allocated: {:.2f} MB".format(
torch.cuda.max_memory_allocated(0) / 1024**2
)
)
print(
"Current GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.memory_reserved(0) / 1024**2
)
)
print(
"Maximum GPU Memory Reserved: {:.2f} MB".format(
torch.cuda.max_memory_reserved(0) / 1024**2
)
)
def capture_timestamp_utc():
@@ -217,3 +249,33 @@ def log_say(text, play_sounds, blocking=False):
if play_sounds:
say(text, blocking)
class TimerManager:
def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
def __enter__(self):
self.start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
@property
def elapsed_seconds(self):
return self.elapsed
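A minimal usage sketch for the new TimerManager context manager:

timings: list[float] = []
for _ in range(3):
    with TimerManager(elapsed_time_list=timings, label="inference", log=False) as t:
        time.sleep(0.01)  # stand-in for the timed workload
print(f"last: {t.elapsed_seconds:.4f}s, mean: {sum(timings) / len(timings):.4f}s")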

View File

@@ -2,6 +2,7 @@ defaults:
- _self_
- env: pusht
- policy: diffusion
- robot: so100
hydra:
run:

View File

@@ -0,0 +1,30 @@
# @package _global_
fps: 400
env:
name: maniskill/pushcube
task: PushCube-v1
image_size: 64
control_mode: pd_ee_delta_pose
state_dim: 25
action_dim: 7
fps: ${fps}
obs: rgb
render_mode: rgb_array
render_size: 64
device: cuda
reward_classifier:
pretrained_path: null
config_path: null
wrapper:
joint_masking_action_space: null
delta_action: null
video_record:
enabled: false
record_dir: maniskill_videos
trajectory_name: trajectory
fps: ${fps}

View File

@@ -1,10 +1,50 @@
# @package _global_
fps: 30
fps: 10
env:
name: real_world
task: null
state_dim: 6
action_dim: 6
state_dim: 15
action_dim: 3
fps: ${fps}
device: mps
wrapper:
crop_params_dict:
observation.images.front: [171, 207, 116, 251]
observation.images.side: [232, 200, 142, 204]
resize_size: [128, 128]
control_time_s: 10
reset_follower_pos: false
use_relative_joint_positions: true
reset_time_s: 5
display_cameras: false
delta_action: null #0.3
joint_masking_action_space: null #[1, 1, 1, 1, 0, 0] # disable wrist and gripper
add_joint_velocity_to_observation: true
add_ee_pose_to_observation: true
# If null then the teleoperation will be used to reset the robot
# Bounds for pushcube_gamepad_lerobot15 dataset and experiments
# fixed_reset_joint_positions: [-19.86, 103.19, 117.33, 42.7, 13.89, 0.297]
# ee_action_space_params: # If null then ee_action_space is not used
# bounds:
# max: [0.291, 0.147, 0.074]
# min: [0.139, -0.143, 0.03]
# Bounds for insertcube_gamepad dataset and experiments
fixed_reset_joint_positions: [20.0, 90., 90., 75., -0.7910156, -0.5673759]
ee_action_space_params:
bounds:
max: [0.25295413, 0.07498981, 0.06862044]
min: [0.2010096, -0.12, 0.0433196]
use_gamepad: true
x_step_size: 0.03
y_step_size: 0.03
z_step_size: 0.03
reward_classifier:
pretrained_path: null # outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
config_path: null # lerobot/configs/policy/hilserl_classifier.yaml
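For orientation, a sketch of how these wrapper crops would be consumed, assuming each crop_params_dict entry follows torchvision's (top, left, height, width) convention:

import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 480, 640)         # stand-in camera frame (C, H, W)
img = F.crop(img, 171, 207, 116, 251)  # observation.images.front params
img = F.resize(img, [128, 128])        # resize_size from this config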

View File

@@ -0,0 +1,61 @@
# @package _global_
defaults:
- _self_
hydra:
run:
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir`, its contents will be overwritten unless you set `resume` to true.
dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
job:
name: default
seed: 13
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
# aractingi/push_cube_square_reward_1_cropped_resized
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
local_files_only: true
train_split_proportion: 0.8
# Required by logger
env:
name: "classifier"
task: "binary_classification"
training:
num_epochs: 6
batch_size: 16
learning_rate: 1e-4
num_workers: 4
grad_clip_norm: 10
use_amp: true
log_freq: 1
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_keys: ["observation.images.front", "observation.images.side"]
label_key: "next.reward"
profile_inference_time: false
profile_inference_time_iters: 20
eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier"
model_name: "helper2424/resnet10" # "facebook/convnext-base-224"
model_type: "cnn"
num_cameras: 2 # Has to be len(training.image_keys)
wandb:
enable: false
project: "classifier-training"
job_name: "classifier_training_0"
disable_artifact: false
device: "mps"
resume: false
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"

View File

@@ -0,0 +1,118 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# +dataset=lerobot/pusht_keypoints
# env=pusht \
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
# dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium"
dataset_repo_id: null
training:
# Offline training dataloader
num_workers: 4
batch_size: 512
grad_clip_norm: 40.0
lr: 3e-4
storage_device: "cuda"
eval_freq: 2500
log_freq: 10
save_freq: 1000000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 200000
offline_buffer_capacity: 100000
online_buffer_seed_size: 0
online_step_before_learning: 500
do_online_rollout_async: false
policy_update_freq: 1
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 1
n_action_steps: 1
shared_encoder: true
# vision_encoder_name: "helper2424/resnet10"
vision_encoder_name: null
# freeze_vision_encoder: true
freeze_vision_encoder: false
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.image: [3, 64, 64]
output_shapes:
action: [7]
camera_number: 1
# Normalization / Unnormalization
# input_normalization_modes: null
input_normalization_modes:
observation.state: min_max
observation.image: mean_std
# input_normalization_params: null
input_normalization_params:
observation.state:
min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
-3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
-6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]
max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
0.4001]
observation.image:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-0.03, -0.03, -0.03, -0.03, -0.03, -0.03, -0.03]
max: [0.03, 0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
output_normalization_shapes:
action: [7]
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.80
temperature_init: 1.0
num_critics: 2 #10
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 4
concurrency:
actor: 'threads'
learner: 'threads'
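For context, a sketch of the transform these min/max vectors feed, assuming min_max normalization maps each dimension to [-1, 1] (and is inverted for actions on output):

import torch

def min_max_normalize(x: torch.Tensor, vmin: torch.Tensor, vmax: torch.Tensor) -> torch.Tensor:
    # Map each dimension to [-1, 1].
    return (x - vmin) / (vmax - vmin) * 2.0 - 1.0

action = torch.tensor([0.0, 0.03, -0.03, 0.0, 0.0, 0.0, 0.0])
print(min_max_normalize(action, torch.full((7,), -0.03), torch.full((7,), 0.03)))
# -> tensor([ 0.,  1., -1.,  0.,  0.,  0.,  0.])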

View File

@@ -0,0 +1,89 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# env=pusht \
# +dataset=lerobot/pusht_keypoints
seed: 1
dataset_repo_id: lerobot/pusht_keypoints
training:
offline_steps: 0
# Offline training dataloader
num_workers: 4
batch_size: 128
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 50000
log_freq: 500
save_freq: 50000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 40000
online_buffer_seed_size: 0
do_online_rollout_async: false
delta_timestamps:
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
action: "[i / ${fps} for i in range(${policy.horizon})]"
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 5
n_action_steps: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.environment_state: [16]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.environment_state: min_max
observation.state: min_max
output_normalization_modes:
action: min_max
# Architecture / modeling.
# Neural networks.
# image_encoder_hidden_dim: 32
discount: 0.99
temperature_init: 1.0
num_critics: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
critic_target_update_weight: 0.005
utd_ratio: 2
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -0,0 +1,120 @@
# @package _global_
# Train with:
#
# python lerobot/scripts/train.py \
# +dataset=lerobot/pusht_keypoints
# env=pusht \
# env.gym.obs_type=environment_state_agent_pos \
seed: 1
dataset_repo_id: aractingi/insertcube_simple
training:
# Offline training dataloader
num_workers: 4
# batch_size: 256
batch_size: 512
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 2500
log_freq: 1
save_freq: 2000000
online_steps: 1000000
online_rollout_n_episodes: 10
online_rollout_batch_size: 10
online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0
online_env_seed: 10000
online_buffer_capacity: 10000
online_buffer_seed_size: 0
online_step_before_learning: 100 #5000
do_online_rollout_async: false
policy_update_freq: 1
# delta_timestamps:
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
# action: "[i / ${fps} for i in range(${policy.horizon})]"
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
policy:
name: sac
pretrained_model_path:
# Input / output structure.
n_action_repeats: 1
horizon: 1
n_action_steps: 1
shared_encoder: true
vision_encoder_name: "helper2424/resnet10"
freeze_vision_encoder: true
input_shapes:
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.state: ["${env.state_dim}"]
observation.images.front: [3, 128, 128]
observation.images.side: [3, 128, 128]
# observation.image: [3, 128, 128]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.front: mean_std
observation.images.side: mean_std
observation.state: min_max
input_normalization_params:
observation.images.front:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.images.side:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
observation.state:
# 6- joint positions, 6- joint velocities, 3- ee position
max: [ 52.822266, 136.14258, 142.03125, 72.1582, 22.675781, -0.5673759, 100., 100., 100., 100., 100., 100., 0.25295413, 0.07498981, 0.06862044]
min: [-2.6367188, 86.572266, 89.82422, 12.392578, -26.015625, -0.5673759, -100., -100., -100., -100., -100., -100., 0.2010096, -0.12, 0.0433196]
output_normalization_modes:
action: min_max
output_normalization_params:
action:
min: [-0.03, -0.03, -0.01]
max: [0.03, 0.03, 0.03]
# Architecture / modeling.
# Neural networks.
image_encoder_hidden_dim: 32
# discount: 0.99
discount: 0.97
temperature_init: 1.0
num_critics: 2 #10
camera_number: 2
num_subsample_critics: null
critic_lr: 3e-4
actor_lr: 3e-4
temperature_lr: 3e-4
# critic_target_update_weight: 0.005
critic_target_update_weight: 0.01
utd_ratio: 2 # 10
actor_learner_config:
learner_host: "127.0.0.1"
learner_port: 50051
policy_parameters_push_frequency: 15
# # Loss coefficients.
# reward_coeff: 0.5
# expectile_weight: 0.9
# value_coeff: 0.1
# consistency_coeff: 20.0
# advantage_scaling: 3.0
# pi_coeff: 0.5
# temporal_decay_coeff: 0.5
# # Target model.
# target_model_momentum: 0.995

View File

@@ -10,7 +10,7 @@ max_relative_target: null
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0031751
port: /dev/tty.usbmodem58760430441
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
@@ -23,7 +23,7 @@ leader_arms:
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0032081
port: /dev/tty.usbmodem585A0083391
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]

View File

@@ -14,11 +14,18 @@ calibration_dir: .cache/calibration/so100
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
joint_position_relative_bounds: null
# max: [100, 100, 100, 100, 100, 100]
# min: [-100, -100, -100, -100, -100, -100]
# max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
# min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
# max: [ 35.06836 , 103.18359 , 127.61719 , 75.58594 , 0., 0.]
# min: [ -8.876953 , 63.808594 , 90.49805 , 49.48242 , 0., 0.]
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
port: /dev/tty.usbmodem585A0077581
port: /dev/tty.usbmodem58760433331
motors:
# name: (index, model)
shoulder_pan: [1, "sts3215"]
@@ -31,7 +38,7 @@ leader_arms:
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
port: /dev/tty.usbmodem585A0080971
port: /dev/tty.usbmodem58760431631
motors:
# name: (index, model)
shoulder_pan: [1, "sts3215"]
@@ -42,15 +49,15 @@ follower_arms:
gripper: [6, "sts3215"]
cameras:
laptop:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480
phone:
front:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 1
fps: 30
width: 640
height: 480
side:
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
camera_index: 0
fps: 30
width: 640
height: 480

View File

@@ -22,13 +22,17 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
from lerobot.common.robot_devices.motors.feetech import (
SCS_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
)
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus as MotorsBusClass
from lerobot.common.robot_devices.motors.feetech import (
FeetechMotorsBus as MotorsBusClass,
)
elif brand == "dynamixel":
from lerobot.common.robot_devices.motors.dynamixel import MODEL_BAUDRATE_TABLE
from lerobot.common.robot_devices.motors.dynamixel import (
X_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE,
)
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus as MotorsBusClass
from lerobot.common.robot_devices.motors.dynamixel import (
DynamixelMotorsBus as MotorsBusClass,
)
else:
raise ValueError(
f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors."
@@ -46,7 +50,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_model = model # Use the motor model passed via argument
# Initialize the MotorBus with the correct port and motor configurations
motor_bus = MotorsBusClass(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
motor_bus = MotorsBusClass(
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
)
# Try to connect to the motor bus and handle any connection-specific errors
try:
@@ -78,20 +84,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_index = present_ids[0]
if motor_index == -1:
raise ValueError("No motors detected. Please ensure you have one motor connected.")
raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}")
if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(SERIES_BAUDRATE_TABLE.values()).index(baudrate_des)
# The write can fail, so we allow retries
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids(
@@ -103,9 +115,13 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
print(f"Setting its index to desired index {motor_idx_des}")
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "ID", motor_idx_des
)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des:
raise OSError("Failed to write index.")
@@ -133,12 +149,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. /dev/tty.usbmodem575E0032081)")
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument(
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
"--port",
type=str,
required=True,
help="Motors bus port (e.g. /dev/tty.usbmodem575E0032081)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument(
"--ID",
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
)
args = parser.parse_args()
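A hypothetical invocation; the port is borrowed from the so100 config above and the script path is an assumption:

python lerobot/scripts/configure_motor.py \
    --port /dev/tty.usbmodem58760431631 \
    --brand feetech \
    --model sts3215 \
    --ID 1 \
    --baudrate 1000000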

View File

@@ -109,6 +109,7 @@ from lerobot.common.robot_devices.control_utils import (
log_control_info,
record_episode,
reset_environment,
reset_follower_position,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
stop_recording,
@@ -117,7 +118,12 @@ from lerobot.common.robot_devices.control_utils import (
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
from lerobot.common.utils.utils import (
init_hydra_config,
init_logging,
log_say,
none_or_int,
)
########################################################################################
# Control modes
@@ -172,7 +178,10 @@ def calibrate(robot: Robot, arms: list[str] | None):
@safe_disconnect
def teleoperate(
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
robot: Robot,
fps: int | None = None,
teleop_time_s: float | None = None,
display_cameras: bool = False,
):
control_loop(
robot,
@@ -191,6 +200,7 @@ def record(
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
assign_rewards: bool = False,
fps: int | None = None,
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
@@ -204,6 +214,8 @@ def record(
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
reset_follower: bool = False,
record_delta_actions: bool = False,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactoring with dataset as an argument
local_files_only: bool = False,
@@ -214,6 +226,14 @@ def record(
policy = None
device = None
use_amp = None
extra_features = (
{
"next.reward": {"dtype": "int64", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
if assign_rewards
else None
)
if single_task:
task = single_task
@@ -222,11 +242,15 @@ def record(
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
elif fps != policy_fps:
logging.warning(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
@@ -242,7 +266,9 @@ def record(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
sanity_check_dataset_robot_compatibility(
dataset, robot, fps, video, extra_features
)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
@@ -253,13 +279,17 @@ def record(
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
image_writer_threads=num_image_writer_threads_per_camera
* len(robot.cameras),
features=extra_features,
)
if not robot.is_connected:
robot.connect()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
listener, events = init_keyboard_listener()
if reset_follower:
initial_position = robot.follower_arms["main"].read("Present_Position")
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
@@ -267,7 +297,9 @@ def record(
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", play_sounds)
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
warmup_record(
robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps
)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
@@ -293,6 +325,7 @@ def record(
device=device,
use_amp=use_amp,
fps=fps,
record_delta_actions=record_delta_actions,
)
# Execute a few seconds without recording to give time to manually reset the environment
@@ -300,9 +333,11 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
if reset_follower:
reset_follower_position(robot, initial_position)
reset_environment(robot, events, reset_time_s)
if events["rerecord_episode"]:
@@ -342,21 +377,26 @@ def replay(
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = False,
replay_delta_actions: bool = False,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(dataset.num_frames):
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
if replay_delta_actions:
action = action + current_joint_positions
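# Delta actions store per-step joint displacements, so replay re-bases each
# one on the follower position read at the top of the loop.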
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
@@ -395,7 +435,10 @@ if __name__ == "__main__":
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_teleop.add_argument(
"--display-cameras",
@@ -407,7 +450,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
task_args.add_argument(
"--single-task",
@@ -456,7 +502,9 @@ if __name__ == "__main__":
default=60,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument(
"--run-compute-stats",
type=int,
@@ -469,12 +517,12 @@ if __name__ == "__main__":
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--tags",
type=str,
nargs="*",
help="Add tags to your dataset on the hub.",
)
# parser_record.add_argument(
# "--tags",
# type=str,
# nargs="*",
# help="Add tags to your dataset on the hub.",
# )
parser_record.add_argument(
"--num-image-writer-processes",
type=int,
@@ -517,10 +565,31 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for nested overrides, e.g. training.log_freq=100)",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables assigning rewards to frames (no assignment by default). When enabled, frames receive a reward of 0 until the space bar is pressed, which assigns a reward of 1. Press the space bar a second time to return to a reward of 0. The assigned reward is reset to 0 when the episode ends.",
)
parser_record.add_argument(
"--record-delta-actions",
type=int,
default=0,
help="Enables the recording of delta actions instead of absolute actions.",
)
parser_record.add_argument(
"--reset-follower",
type=int,
default=0,
help="Resets the follower to its initial position while resetting the environment, to avoid having the follower start from an awkward position in the next episode.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_replay.add_argument(
"--root",
@@ -540,7 +609,15 @@ if __name__ == "__main__":
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_replay.add_argument(
"--replay-delta-actions",
type=int,
default=0,
help="Enables the replay of delta actions instead of absolute actions.",
)
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episode to replay."
)
args = parser.parse_args()
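A hypothetical record invocation combining the new flags; the robot config, repo id, and task are illustrative:

python lerobot/scripts/control_robot.py record \
    --robot-path lerobot/configs/robot/so100.yaml \
    --fps 30 \
    --repo-id $USER/pick_cube \
    --single-task "Pick up the cube." \
    --assign-rewards 1 \
    --record-delta-actions 1 \
    --reset-follower 1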

View File

@@ -135,7 +135,11 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
return {
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@@ -156,7 +160,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.")
break
@@ -183,21 +190,35 @@ def record(
resume: bool = False,
local_files_only: bool = False,
run_compute_stats: bool = True,
assign_rewards: bool = False,
) -> LeRobotDataset:
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# create sim env
env = env()
@@ -227,7 +248,11 @@ def record(
shape = env.observation_space[key].shape
if not key.startswith("observation.image."):
key = "observation.image." + key
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
features[key] = {
"dtype": "video",
"names": ["channel", "height", "width"],
"shape": shape,
}
for key, obs_key in state_keys_dict.items():
features[key] = {
@@ -236,7 +261,12 @@ def record(
"shape": env.observation_space[obs_key].shape,
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features["action"] = {
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
@@ -288,6 +318,13 @@ def record(
"timestamp": env_timestamp,
}
# Overwrite environment reward with manually assigned reward
if assign_rewards:
frame["next.reward"] = events["next.reward"]
# Should success always be false to match what we do in control_utils?
frame["next.success"] = False
for key in image_keys:
if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key]
@@ -329,7 +366,9 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info("Waiting for a few seconds before starting next episode recording...")
logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
@@ -347,7 +386,12 @@ def record(
def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
env,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
):
env = env()
@@ -394,7 +438,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_record.add_argument(
"--root",
@@ -420,7 +467,9 @@ if __name__ == "__main__":
required=True,
help="A description of the task performed during recording that can be used as a language instruction.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument(
"--run-compute-stats",
type=int,
@@ -472,9 +521,19 @@ if __name__ == "__main__":
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables assigning rewards to frames (no assignment by default). When enabled, frames receive a reward of 0 until the space bar is pressed, which assigns a reward of 1. Press the space bar a second time to return to a reward of 0. The assigned reward is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_replay.add_argument(
"--root",
@@ -488,7 +547,9 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episode to replay."
)
args = parser.parse_args()

View File

@@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env`
@@ -77,7 +81,9 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info))
return info

View File

@@ -149,7 +149,9 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
observation = {
key: observation[key].to(device, non_blocking=True) for key in observation
}
with torch.inference_mode():
action = policy.select_action(observation)
@@ -166,7 +168,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else:
successes = [False] * env.num_envs
@@ -180,9 +185,13 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@@ -200,7 +209,9 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations
return ret
@@ -255,7 +266,9 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@@ -267,7 +280,9 @@ def eval_policy(
episode_data: dict | None = None
# we don't want a progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -278,7 +293,8 @@ def eval_policy(
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
)
rollout_data = rollout(
env,
@@ -296,13 +312,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
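# Intuition: with n_steps = 5 and done_indices = tensor([2]), the comparison
# torch.arange(5) <= 3 yields [1, 1, 1, 1, 0], so the `+ 1` keeps the reward
# and success emitted on the done step itself.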
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@@ -315,17 +340,27 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -343,7 +378,9 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@@ -352,7 +389,9 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
)
# Wait till all video rendering threads are done.
@@ -398,7 +437,11 @@ def eval_policy(
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
@@ -416,12 +459,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
}
# For the last observation frame, all other keys will just be copy padded.
@@ -437,7 +484,9 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict
@@ -450,7 +499,9 @@ def main(
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if pretrained_policy_path is not None:
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
hydra_cfg = init_hydra_config(
str(pretrained_policy_path / "config.yaml"), config_overrides
)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
@@ -481,15 +532,23 @@ def main(
logging.info("Making policy.")
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
policy = make_policy(
hydra_cfg=hydra_cfg,
pretrained_policy_name_or_path=str(pretrained_policy_path),
)
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
policy = make_policy(
hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats
)
assert isinstance(policy, nn.Module)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(),
):
info = eval_policy(
env,
policy,
@@ -511,16 +570,14 @@ def main(
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
pretrained_policy_path = Path(
snapshot_download(pretrained_policy_name_or_path, revision=revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
error_message = "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
error_message = "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
@@ -555,7 +612,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
@@ -571,7 +630,11 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.pretrained_policy_name_or_path is None:
- main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
+ main(
+     hydra_cfg_path=args.config,
+     out_dir=args.out_dir,
+     config_overrides=args.overrides,
+ )
else:
pretrained_policy_path = get_pretrained_policy_path(
args.pretrained_policy_name_or_path, revision=args.revision

View File

@@ -0,0 +1,426 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
```
python lerobot/scripts/eval_on_robot.py \
-p outputs/train/model/checkpoints/005000/pretrained_model \
eval.n_episodes=10
```
Test reward classifier with teleoperation (you need to press space to take over)
```
python lerobot/scripts/eval_on_robot.py \
--robot-path lerobot/configs/robot/so100.yaml \
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
--display-cameras 1
```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
"""
import argparse
import logging
import time
import cv2
import numpy as np
import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
init_logging,
log_say,
)
def get_classifier(pretrained_path, config_path):
if pretrained_path is None or config_path is None:
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
return model
def rollout(
robot: Robot,
policy: Policy,
reward_classifier,
fps: int,
control_time_s: float = 20,
use_amp: bool = True,
display_cameras: bool = False,
) -> dict:
"""Run a batched policy rollout on the real robot.
The return dictionary contains:
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
keys. NOTE the that this has an extra sequence element relative to the other keys in the
dictionary. This is because an extra observation is included for after the environment is
terminated or truncated.
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
including the last observations).
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.
Args:
robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module.
Returns:
The dictionary described above.
"""
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
# device = get_device_from_parameters(policy)
# define keyboard listener
listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()
# NOTE: sorting to make sure the key sequence is the same during training and testing.
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
image_keys.sort()
all_actions = []
all_rewards = []
all_successes = []
start_episode_t = time.perf_counter()
init_pos = robot.follower_arms["main"].read("Present_Position")
timestamp = 0.0
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
# Apply the next action.
while events["pause_policy"] and not events["human_intervention_step"]:
busy_wait(0.5)
if events["human_intervention_step"]:
# take over the robot's actions
observation, action = robot.teleop_step(record_data=True)
action = action["action"] # teleop step returns torch tensors but in a dict
else:
# explore with policy
with torch.inference_mode():
# TODO (michel-aractingi) replace this part with policy (predict_action)
action = robot.follower_arms["main"].read("Present_Position")
action = torch.from_numpy(action)
robot.send_action(action)
# action = predict_action(observation, policy, device, use_amp)
observation = robot.capture_observation()
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward)
# print("REWARD : ", reward)
all_actions.append(action)
all_successes.append(torch.tensor([False]))
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
events["human_intervention_step"] = False
events["pause_policy"] = False
break
reset_follower_position(robot, target_position=init_pos)
dones = torch.tensor([False] * len(all_actions))
dones[-1] = True
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
"action": torch.stack(all_actions, dim=1),
"next.reward": torch.stack(all_rewards, dim=1),
"next.success": torch.stack(all_successes, dim=1),
"done": dones,
}
listener.stop()
return ret
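The docstring above notes that `done` is cumulative (once True, it stays True) and can be used to mask extraneous frames. A minimal sketch of that masking, assuming the episode contains at least one terminal step:
```
import torch

# Hypothetical cumulative "done" tensor and rewards for one batch element.
done = torch.tensor([[False, False, True, True, True]])
reward = torch.tensor([[0.1, 0.2, 1.0, 0.0, 0.0]])
first_done = done.int().argmax(dim=1, keepdim=True)  # index of first True (assumes one exists)
idx = torch.arange(done.shape[1]).unsqueeze(0)
keep = idx <= first_done  # keep frames up to and including the terminal one
print((reward * keep).sum(dim=1))  # tensor([1.3000])
```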
def eval_policy(
robot: Robot,
policy: torch.nn.Module,
fps: float,
n_episodes: int,
control_time_s: int = 20,
use_amp: bool = True,
display_cameras: bool = False,
reward_classifier_pretrained_path: str | None = None,
reward_classifier_config_file: str | None = None,
) -> dict:
"""
Args:
env: The batch of environments.
policy: The policy.
n_episodes: The number of episodes to evaluate.
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
# TODO (michel-aractingi) comment this out for testing with a fixed policy
# assert isinstance(policy, Policy)
# policy.eval()
sum_rewards = []
max_rewards = []
successes = []
rollouts = []
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar:
rollout_data = rollout(
robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
)
rollouts.append(rollout_data)
sum_rewards.append(sum(rollout_data["next.reward"]))
max_rewards.append(max(rollout_data["next.reward"]))
successes.append(rollout_data["next.success"][-1])
info = {
"per_episode": [
{
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"pc_success": success * 100,
}
for i, (sum_reward, max_reward, success) in enumerate(
zip(
sum_rewards[:n_episodes],
max_rewards[:n_episodes],
successes[:n_episodes],
strict=False,
)
)
],
"aggregated": {
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
"eval_s": time.time() - start_eval,
"eval_ep_s": (time.time() - start_eval) / n_episodes,
},
}
if robot.is_connected:
robot.disconnect()
return info
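For orientation, this is the shape of the dictionary `eval_policy` returns; all values below are illustrative:
```
# Illustrative structure only; the numbers are made up.
info = {
    "per_episode": [
        {"episode_ix": 0, "sum_reward": 3.0, "max_reward": 1.0, "pc_success": 100.0},
    ],
    "aggregated": {
        "avg_sum_reward": 3.0,
        "avg_max_reward": 1.0,
        "pc_success": 100.0,
        "eval_s": 42.0,     # total evaluation wall-clock time in seconds
        "eval_ep_s": 21.0,  # average seconds per episode
    },
}
```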
def init_keyboard_listener():
# Allow the user to exit early while recording an episode or resetting the environment
# by tapping the right arrow key '->'. This might require sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["pause_policy"] = False
events["human_intervention_step"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
# check if first space press then pause the policy for the user to get ready
# if second space press then the user is ready to start intervention
if not events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
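The space-bar handler above implements a two-press protocol: the first press pauses the policy so the operator can get ready, the second starts the intervention. A stripped-down sketch of the same state transition:
```
# Minimal sketch of the two-press space-bar protocol.
events = {"pause_policy": False, "human_intervention_step": False}

def on_space(events: dict) -> None:
    if not events["pause_policy"]:
        events["pause_policy"] = True  # first press: pause policy, operator gets ready
    else:
        events["human_intervention_step"] = True  # second press: intervention starts

on_space(events)
on_space(events)
assert events == {"pause_policy": True, "human_intervention_step": True}
```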
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
group.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
"Where to save the evaluation outputs. If not provided, outputs are saved in "
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
),
)
parser.add_argument(
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",
type=str,
default=None,
help="Path to the pretrained classifier weights.",
)
parser.add_argument(
"--reward-classifier-config-file",
type=str,
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
if not robot.is_connected:
robot.connect()
eval_policy(
robot,
None,
fps=40,
n_episodes=2,
control_time_s=100,
display_cameras=args.display_cameras,
reward_classifier_config_file=args.reward_classifier_config_file,
reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
)

View File

@@ -32,9 +32,13 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. No difference was found ({ports_diff})."
)
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. More than one port was found ({ports_diff})."
)
if __name__ == "__main__":

View File

@@ -56,24 +56,42 @@ from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
- from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
+ from lerobot.common.datasets.utils import (
+     create_branch,
+     create_lerobot_dataset_card,
+     flatten_dict,
+ )
def get_from_raw_to_lerobot_format_fn(raw_format: str):
if raw_format == "pusht_zarr":
- from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format == "umi_zarr":
- from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format == "aloha_hdf5":
- from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format in ["rlds", "openx"]:
- from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format == "dora_parquet":
- from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format == "xarm_pkl":
- from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import (
+     from_raw_to_lerobot_format,
+ )
elif raw_format == "cam_png":
- from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
+ from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import (
+     from_raw_to_lerobot_format,
+ )
else:
raise ValueError(
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
@@ -83,7 +101,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
def save_meta_data(
- info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
+ info: dict[str, Any],
+ stats: dict,
+ episode_data_index: dict[str, list],
+ meta_data_dir: Path,
):
meta_data_dir.mkdir(parents=True, exist_ok=True)
@@ -97,12 +118,16 @@ def save_meta_data(
save_file(flatten_dict(stats), stats_path)
# save episode_data_index
- episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
+ episode_data_index = {
+     key: torch.tensor(episode_data_index[key]) for key in episode_data_index
+ }
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
- def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
+ def push_meta_data_to_hub(
+     repo_id: str, meta_data_dir: str | Path, revision: str | None
+ ):
"""Expect all meta data files to be all stored in a single "meta_data" directory.
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
"""
@@ -187,7 +212,9 @@ def push_dataset_to_hub(
if force_override:
shutil.rmtree(local_dir)
elif not resume:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
raise ValueError(
f"`local_dir` already exists ({local_dir}). Use `--force-override 1`."
)
meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
@@ -223,7 +250,9 @@ def push_dataset_to_hub(
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
if local_dir:
- hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
+ hf_dataset = hf_dataset.with_format(
+     None
+ )  # to remove transforms that can't be saved
hf_dataset.save_to_disk(str(local_dir / "train"))
if push_to_hub or local_dir:

View File

@@ -0,0 +1,641 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from statistics import mean, quantiles
from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers
# from lerobot.scripts.eval import eval_policy
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import time
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
bytes_to_state_dict,
)
from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks,
send_bytes_in_chunks,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from lerobot.common.robot_devices.utils import busy_wait
from torch.multiprocessing import Queue, Event
from queue import Empty
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.utils import get_last_item_from_queue
ACTOR_SHUTDOWN_TIMEOUT = 30
def receive_policy(
cfg: DictConfig,
parameters_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
iterator = learner_client.StreamParameters(hilserl_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def transitions_stream(
shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return hilserl_pb2.Empty()
def interactions_stream(
shutdown_event: any, # Event,
interactions_queue: Queue,
) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=5)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
yield from send_bytes_in_chunks(
message,
hilserl_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return hilserl_pb2.Empty()
def send_transitions(
cfg: DictConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes:
- **Transition Data:**
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
learner_client.SendTransitions(
transitions_stream(shutdown_event, transitions_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: DictConfig,
interactions_queue: Queue,
shutdown_event: any, # Event,
learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> hilserl_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- **Interaction Messages:**
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
setup_process_handlers(False)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(
interactions_stream(shutdown_event, interactions_queue)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
@lru_cache(maxsize=1)
def learner_service_client(
host="127.0.0.1", port=50051
) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection,
so we need to create only one client and reuse it.
"""
import json
service_config = {
"methodConfig": [
{
"name": [{}], # Applies to ALL methods in ALL services
"retryPolicy": {
"maxAttempts": 5, # Max retries (total attempts = 5)
"initialBackoff": "0.1s", # First retry after 0.1s
"maxBackoff": "2s", # Max wait time between retries
"backoffMultiplier": 2, # Exponential backoff factor
"retryableStatusCodes": [
"UNAVAILABLE",
"DEADLINE_EXCEEDED",
], # Retries on network failures
},
}
]
}
service_config_json = json.dumps(service_config)
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.enable_retries", 1),
("grpc.service_config", service_config_json),
],
)
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
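Since the factory is wrapped in `lru_cache(maxsize=1)`, repeated calls with the same host and port return the same stub and channel, which is what the single-connection comment asks for. A quick check relying only on standard `functools.lru_cache` behavior (`grpc.insecure_channel` does not connect eagerly, so no server is needed):
```
# Both calls hit the lru_cache, so the gRPC channel is created exactly once.
stub_a, chan_a = learner_service_client(host="127.0.0.1", port=50051)
stub_b, chan_b = learner_service_client(host="127.0.0.1", port=50051)
assert stub_a is stub_b and chan_a is chan_b
```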
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
if not parameters_queue.empty():
logging.info("[ACTOR] Load new parameters from Learner.")
bytes_state_dict = get_last_item_from_queue(parameters_queue)
state_dict = bytes_to_state_dict(bytes_state_dict)
state_dict = move_state_dict_to_device(state_dict, device=device)
policy.load_state_dict(state_dict)
def act_with_policy(
cfg: DictConfig,
robot: Robot,
reward_classifier: nn.Module,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg (DictConfig): Configuration settings for the interaction process.
"""
logging.info("make_env online")
online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
# TODO: At some point we should just need to make a SAC policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
device=device,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
list_policy_time = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
for interaction_step in range(cfg.training.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.training.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time,
label="Policy inference time",
log=False,
) as timer: # noqa: F841
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK: We have only one env, but we want to batch it; this will be resolved with the torch box
action = (
torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# TODO: Check the shape
# NOTE: Demonstrations are recorded beforehand with the full action space,
# but sometimes, for example, we want to deactivate the gripper
action = info["action_intervention"]
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
reward=reward,
next_state=next_obs,
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,  # TODO: Handle information for the transition, is_demonstration: bool
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
# HACK: We have only one env, but we want to batch it; this will be resolved with the torch box
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(list_policy_time)
list_policy_time.clear()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
sum_reward_episode = 0.0
episode_intervention = False
# Reset intervention counters
episode_intervention_steps = 0
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.fps - dt_time)
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
stats = {}
list_policy_fps = [1.0 / t for t in list_policy_time]
if len(list_policy_fps) > 1:
policy_fps = mean(list_policy_fps)
quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
stats = {
"Policy frequency [Hz]": policy_fps,
"Policy frequency 90th-p [Hz]": quantiles_90,
}
return stats
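`quantiles(xs, n=10)` from the standard `statistics` module returns nine decile cut points, so indexing with `[-1]` as above yields the 90th percentile. A self-contained check with hypothetical frequency readings:
```
from statistics import quantiles

samples = [28.0, 29.0, 29.5, 29.8, 30.0, 30.1, 30.2, 30.3, 30.5, 31.0]  # hypothetical Hz values
deciles = quantiles(samples, n=10)  # nine cut points
print(deciles[-1])                  # 90th percentile
```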
def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
)
def establish_learner_connection(
stub,
shutdown_event: any, # Event,
attempts=30,
):
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.actor == "threads"
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
init_logging(log_file="actor.log")
robot = make_robot(cfg=cfg.robot)
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.actor_learner_config.learner_host,
port=cfg.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# If we use multithreading, we can reuse the channel
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process = concurrency_entity(
target=send_transitions,
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
# HACK: FOR MANISKILL we do not have a reward classifier
# TODO: Remove this once we merge into main
reward_classifier = None
if (
cfg.env.reward_classifier.pretrained_path is not None
and cfg.env.reward_classifier.config_path is not None
):
reward_classifier = get_classifier(
pretrained_path=cfg.env.reward_classifier.pretrained_path,
config_path=cfg.env.reward_classifier.config_path,
)
act_with_policy(
cfg,
robot,
reward_classifier,
shutdown_event,
parameters_queue,
transitions_queue,
interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
if __name__ == "__main__":
actor_cli()

File diff suppressed because it is too large.

View File

@@ -0,0 +1,286 @@
import argparse # noqa: I001
import json
from copy import deepcopy
from typing import Dict, Tuple
from pathlib import Path
import cv2
# import torch.nn.functional as F # noqa: N812
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
ix, iy = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal ix, iy, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
ix, iy = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(iy, y)
left = min(ix, x)
bottom = max(iy, y)
right = max(ix, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
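A hedged usage sketch for `select_rect_roi`; `example.png` is a placeholder path:
```
import cv2

img = cv2.imread("example.png")  # placeholder path
if img is not None:
    roi = select_rect_roi(img)
    if roi is not None:
        top, left, height, width = roi
        print(f"ROI: top={top}, left={left}, height={height}, width={width}")
```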
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: Dict[str, Tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: Tuple[int, int] = (128, 128),
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
# 2. Process each episode in the original dataset.
episodes_info = original_dataset.meta.episodes
# (Sort episodes by episode_index for consistency.)
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
# Use the first task from the episode metadata (or "unknown" if not provided)
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
last_episode_index = 0
for sample in tqdm(original_dataset):
episode_index = sample.pop("episode_index")
if episode_index != last_episode_index:
new_dataset.save_episode(task, encode_videos=True)
last_episode_index = episode_index
sample.pop("frame_index")
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
new_sample = sample.copy()
# Loop over each observation key that should be cropped/resized.
for key, params in crop_params_dict.items():
if key in new_sample:
top, left, height, width = params
# Apply crop then resize.
cropped = F.crop(new_sample[key], top, left, height, width)
resized = F.resize(cropped, resize_size)
new_sample[key] = resized
# Add the transformed frame to the new dataset.
new_dataset.add_frame(new_sample)
# save last episode
new_dataset.save_episode(task, encode_videos=True)
# Optionally, consolidate the new dataset to compute statistics and update video info.
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
new_dataset.push_to_hub(tags=None)
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path) as f:
rois = json.load(f)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)
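The file passed via `--crop-params-path` is a JSON mapping from image keys to `(top, left, height, width)` lists, the same structure the script writes back to `meta/crop_params.json`. A hedged illustration (keys and numbers are made up):
```
import json

# Hypothetical crop_params.json content; keys and values are illustrative only.
example_rois = {
    "observation.images.laptop": [10, 32, 200, 200],  # (top, left, height, width)
    "observation.images.phone": [0, 0, 240, 320],
}
print(json.dumps(example_rois, indent=4))
```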

View File

@@ -0,0 +1,797 @@
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import logging
import time
import torch
import numpy as np
import argparse
logging.basicConfig(level=logging.INFO)
class InputController:
"""Base class for input controllers that generate motion deltas."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
"""
Initialize the controller.
Args:
x_step_size: Base movement step size in meters
y_step_size: Base movement step size in meters
z_step_size: Base movement step size in meters
"""
self.x_step_size = x_step_size
self.y_step_size = y_step_size
self.z_step_size = z_step_size
self.running = True
self.episode_end_status = None # None, "success", or "failure"
def start(self):
"""Start the controller and initialize resources."""
pass
def stop(self):
"""Stop the controller and release resources."""
pass
def get_deltas(self):
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def should_quit(self):
"""Return True if the user has requested to quit."""
return not self.running
def update(self):
"""Update controller state - call this once per frame."""
pass
def __enter__(self):
"""Support for use in 'with' statements."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensure resources are released when exiting 'with' block."""
self.stop()
def get_episode_end_status(self):
"""
Get the current episode end status.
Returns:
None if episode should continue, "success" or "failure" otherwise
"""
status = self.episode_end_status
self.episode_end_status = None # Reset after reading
return status
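Because `InputController` defines `__enter__`/`__exit__`, controllers are intended to be used in `with` blocks. A hypothetical polling loop using the `KeyboardController` defined just below:
```
# Hypothetical usage; the step size and loop body are illustrative.
with KeyboardController(x_step_size=0.01) as controller:
    while not controller.should_quit():
        controller.update()
        dx, dy, dz = controller.get_deltas()
        # ... apply (dx, dy, dz) to the end-effector target here ...
        if controller.get_episode_end_status() is not None:
            break
```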
class KeyboardController(InputController):
"""Generate motion deltas from keyboard input."""
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01):
super().__init__(x_step_size, y_step_size, z_step_size)
self.key_states = {
"forward_x": False,
"backward_x": False,
"forward_y": False,
"backward_y": False,
"forward_z": False,
"backward_z": False,
"quit": False,
"success": False,
"failure": False,
}
self.listener = None
def start(self):
"""Start the keyboard listener."""
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = True
elif key == keyboard.Key.down:
self.key_states["backward_x"] = True
elif key == keyboard.Key.left:
self.key_states["forward_y"] = True
elif key == keyboard.Key.right:
self.key_states["backward_y"] = True
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = True
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = True
elif key == keyboard.Key.esc:
self.key_states["quit"] = True
self.running = False
return False
elif key == keyboard.Key.enter:
self.key_states["success"] = True
self.episode_end_status = "success"
elif key == keyboard.Key.backspace:
self.key_states["failure"] = True
self.episode_end_status = "failure"
except AttributeError:
pass
def on_release(key):
try:
if key == keyboard.Key.up:
self.key_states["forward_x"] = False
elif key == keyboard.Key.down:
self.key_states["backward_x"] = False
elif key == keyboard.Key.left:
self.key_states["forward_y"] = False
elif key == keyboard.Key.right:
self.key_states["backward_y"] = False
elif key == keyboard.Key.shift:
self.key_states["backward_z"] = False
elif key == keyboard.Key.shift_r:
self.key_states["forward_z"] = False
elif key == keyboard.Key.enter:
self.key_states["success"] = False
elif key == keyboard.Key.backspace:
self.key_states["failure"] = False
except AttributeError:
pass
self.listener = keyboard.Listener(on_press=on_press, on_release=on_release)
self.listener.start()
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" Enter: End episode with SUCCESS")
print(" Backspace: End episode with FAILURE")
print(" ESC: Exit")
def stop(self):
"""Stop the keyboard listener."""
if self.listener and self.listener.is_alive():
self.listener.stop()
def get_deltas(self):
"""Get the current movement deltas from keyboard state."""
delta_x = delta_y = delta_z = 0.0
if self.key_states["forward_x"]:
delta_x += self.x_step_size
if self.key_states["backward_x"]:
delta_x -= self.x_step_size
if self.key_states["forward_y"]:
delta_y += self.y_step_size
if self.key_states["backward_y"]:
delta_y -= self.y_step_size
if self.key_states["forward_z"]:
delta_z += self.z_step_size
if self.key_states["backward_z"]:
delta_z -= self.z_step_size
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if ESC was pressed."""
return self.key_states["quit"]
def should_save(self):
"""Return True if Enter was pressed (save episode)."""
return self.key_states["success"] or self.key_states["failure"]
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
self.intervention_flag = False
def start(self):
"""Initialize pygame and the gamepad."""
import pygame
pygame.init()
pygame.joystick.init()
if pygame.joystick.get_count() == 0:
logging.error(
"No gamepad detected. Please connect a gamepad and try again."
)
self.running = False
return
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
print("Gamepad controls:")
print(" Left analog stick: Move in X-Y plane")
print(" Right analog stick (vertical): Move in Z axis")
print(" B/Circle button: Exit")
print(" Y/Triangle button: End episode with SUCCESS")
print(" A/Cross button: End episode with FAILURE")
print(" X/Square button: Rerecord episode")
def stop(self):
"""Clean up pygame resources."""
import pygame
if pygame.joystick.get_init():
if self.joystick:
self.joystick.quit()
pygame.joystick.quit()
pygame.quit()
def update(self):
"""Process pygame events to get fresh gamepad readings."""
import pygame
for event in pygame.event.get():
if event.type == pygame.JOYBUTTONDOWN:
if event.button == 3:
self.episode_end_status = "success"
# A button (1) for failure
elif event.button == 1:
self.episode_end_status = "failure"
# X button (0) for rerecord
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]:
self.episode_end_status = None
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):
self.intervention_flag = True
else:
self.intervention_flag = False
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
import pygame
try:
# Read joystick axes
# Left stick X and Y (typically axes 0 and 1)
x_input = self.joystick.get_axis(0) # Left/Right
y_input = self.joystick.get_axis(1) # Up/Down (often inverted)
# Right stick Y (typically axis 3 or 4)
z_input = self.joystick.get_axis(3) # Up/Down for Z
# Apply deadzone to avoid drift
x_input = 0 if abs(x_input) < self.deadzone else x_input
y_input = 0 if abs(y_input) < self.deadzone else y_input
z_input = 0 if abs(z_input) < self.deadzone else z_input
# Calculate deltas (note: may need to invert axes depending on controller)
delta_x = -y_input * self.y_step_size # Forward/backward
delta_y = -x_input * self.x_step_size # Left/right
delta_z = -z_input * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
except pygame.error:
logging.error("Error reading gamepad. Is it still connected?")
return 0.0, 0.0, 0.0
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
class GamepadControllerHID(InputController):
"""Generate motion deltas from gamepad input using HIDAPI."""
def __init__(
self,
x_step_size=0.01,
y_step_size=0.01,
z_step_size=0.01,
deadzone=0.1,
vendor_id=0x046D,
product_id=0xC219,
):
"""
Initialize the HID gamepad controller.
Args:
step_size: Base movement step size in meters
z_scale: Scaling factor for Z-axis movement
deadzone: Joystick deadzone to prevent drift
vendor_id: USB vendor ID of the gamepad (default: Logitech)
product_id: USB product ID of the gamepad (default: RumblePad 2)
"""
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.vendor_id = vendor_id
self.product_id = product_id
self.device = None
self.device_info = None
# Movement values (normalized from -1.0 to 1.0)
self.left_x = 0.0
self.left_y = 0.0
self.right_x = 0.0
self.right_y = 0.0
# Button states
self.buttons = {}
self.quit_requested = False
self.save_requested = False
self.intervention_flag = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
import hid
devices = hid.enumerate()
for device in devices:
if (
device["vendor_id"] == self.vendor_id
and device["product_id"] == self.product_id
):
logging.info(
f"Found gamepad: {device.get('product_string', 'Unknown')}"
)
return device
logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
f"product ID 0x{self.product_id:04X} found"
)
return None
def start(self):
"""Connect to the gamepad using HIDAPI."""
import hid
self.device_info = self.find_device()
if not self.device_info:
self.running = False
return
try:
logging.info(f"Connecting to gamepad at path: {self.device_info['path']}")
self.device = hid.device()
self.device.open_path(self.device_info["path"])
self.device.set_nonblocking(1)
manufacturer = self.device.get_manufacturer_string()
product = self.device.get_product_string()
logging.info(f"Connected to {manufacturer} {product}")
logging.info("Gamepad controls (HID mode):")
logging.info(" Left analog stick: Move in X-Y plane")
logging.info(" Right analog stick: Move in Z axis (vertical)")
logging.info(" Button 1/B/Circle: Exit")
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
logging.info(" Button 3/X/Square: End episode with FAILURE")
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
logging.error(
"You might need to run this with sudo/admin privileges on some systems"
)
self.running = False
def stop(self):
"""Close the HID device connection."""
if self.device:
self.device.close()
self.device = None
def update(self):
"""
Read and process the latest gamepad data.
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
"""
for _ in range(10):
self._update()
def _update(self):
"""Read and process the latest gamepad data."""
if not self.device or not self.running:
return
try:
# Read data from the gamepad
data = self.device.read(64)
if data:
# Interpret gamepad data - this will vary by controller model
# These offsets are for the Logitech RumblePad 2
if len(data) >= 8:
# Normalize joystick values from 0-255 to -1.0-1.0
self.left_x = (data[1] - 128) / 128.0
self.left_y = (data[2] - 128) / 128.0
self.right_x = (data[3] - 128) / 128.0
self.right_y = (data[4] - 128) / 128.0
# Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = (
0 if abs(self.right_x) < self.deadzone else self.right_x
)
self.right_y = (
0 if abs(self.right_y) < self.deadzone else self.right_y
)
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
# Check if RB is pressed then the intervention flag should be set
self.intervention_flag = data[6] == 2
# Check if Y/Triangle button (bit 7) is pressed for saving
# Check if X/Square button (bit 5) is pressed for failure
# Check if A/Cross button (bit 4) is pressed for rerecording
if buttons & 1 << 7:
self.episode_end_status = "success"
elif buttons & 1 << 5:
self.episode_end_status = "failure"
elif buttons & 1 << 4:
self.episode_end_status = "rerecord_episode"
else:
self.episode_end_status = None
except OSError as e:
logging.error(f"Error reading from gamepad: {e}")
def get_deltas(self):
"""Get the current movement deltas from gamepad state."""
# Calculate deltas - invert as needed based on controller orientation
delta_x = -self.left_y * self.x_step_size # Forward/backward
delta_y = -self.left_x * self.y_step_size # Left/right
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if quit button was pressed."""
return self.quit_requested
def should_save(self):
"""Return True if save button was pressed."""
return self.save_requested
def should_intervene(self):
"""Return True if intervention flag was set."""
return self.intervention_flag
def test_forward_kinematics(robot, fps=10):
logging.info("Testing Forward Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
robot.teleop_step()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
logging.info(f"EE Position: {ee_pos[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def test_inverse_kinematics(robot, fps=10):
logging.info("Testing Inverse Kinematics")
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Inverse Kinematics")
fk_func = RobotKinematics.fk_gripper_tip
timestep = time.perf_counter()
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = fk_func(joint_positions)
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
desired_ee_pos = leader_ee
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Leader EE: {leader_ee[:3,3]}, Follower EE: {ee_pos[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
logging.info("Testing Delta End-Effector Control")
timestep = time.perf_counter()
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
fk_func = RobotKinematics.fk_gripper_tip
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
initial_leader_ee = fk_func(leader_joint_positions)
desired_ee_pos = np.eye(4)  # Start from the identity transform
while time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Get leader state for teleoperation
leader_joint_positions = robot.leader_arms["main"].read("Present_Position")
leader_ee = fk_func(leader_joint_positions)
# Get current state
# obs = robot.capture_observation()
# joint_positions = obs["observation.state"].cpu().numpy()
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
# Calculate delta between leader and follower end-effectors
# Scaling factor can be adjusted for sensitivity
scaling_factor = 1.0
ee_delta = (leader_ee - initial_leader_ee) * scaling_factor
# Apply delta to current position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + ee_delta[0, 3]
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + ee_delta[1, 3]
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + ee_delta[2, 3]
if np.any(np.abs(ee_delta[:3, 3]) > 0.01):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
initial_leader_ee = leader_ee.copy()
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(
f"Current EE: {current_ee_pos[:3,3]}, Desired EE: {desired_ee_pos[:3,3]}"
)
logging.info(f"Delta EE: {ee_delta[:3,3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics(
robot, controller, fps=10, bounds=None, fk_func=None
):
"""
Control a robot using delta end-effector movements from any input controller.
Args:
robot: Robot instance to control
controller: InputController instance (keyboard, gamepad, etc.)
fps: Control frequency in Hz
bounds: Optional position limits
fk_func: Forward kinematics function to use
"""
if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip
logging.info(
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
)
# Initial position capture
obs = robot.capture_observation()
joint_positions = obs["observation.state"].cpu().numpy()
current_ee_pos = fk_func(joint_positions)
# Start from the identity transform; the translation part is filled in each loop iteration
desired_ee_pos = np.eye(4)
timestep = time.perf_counter()
with controller:
while not controller.should_quit() and time.perf_counter() - timestep < 60.0:
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get current robot state
joint_positions = robot.follower_arms["main"].read("Present_Position")
current_ee_pos = fk_func(joint_positions)
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Update desired position
desired_ee_pos[0, 3] = current_ee_pos[0, 3] + delta_x
desired_ee_pos[1, 3] = current_ee_pos[1, 3] + delta_y
desired_ee_pos[2, 3] = current_ee_pos[2, 3] + delta_z
# Apply bounds if provided
if bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
)
# Only send commands if there's actual movement
if any(abs(v) > 0.001 for v in (delta_x, delta_y, delta_z)):
# Compute joint targets via inverse kinematics
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True, fk_func=fk_func
)
# Send command to robot
robot.send_action(torch.from_numpy(target_joint_state))
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_gym_env(env, controller, fps: int = 30):
"""
Control a robot through a gym environment using an input controller.
Args:
env: A gym environment created with make_robot_env
controller: InputController instance (keyboard, gamepad, etc.)
fps: Target control frequency
"""
logging.info("Testing Keyboard Control of Gym Environment")
print("Keyboard controls:")
print(" Arrow keys: Move in X-Y plane")
print(" Shift and Shift_R: Move in Z axis")
print(" ESC: Exit")
# Reset the environment to get initial observation
obs, info = env.reset()
try:
with controller:
while not controller.should_quit():
loop_start_time = time.perf_counter()
# Process input events
controller.update()
# Get movement deltas from the controller
delta_x, delta_y, delta_z = controller.get_deltas()
# Create the action vector
action = np.array([delta_x, delta_y, delta_z])
# Only step the environment when there is actual movement
if any(abs(v) > 0.001 for v in (delta_x, delta_y, delta_z)):
# Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step(
(action_tensor, False)
)
# Log information
logging.info(
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
)
logging.info(f"Reward: {reward}")
# Reset if episode ended
if terminated or truncated:
logging.info("Episode ended, resetting environment")
obs, info = env.reset()
# Maintain target frame rate
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
finally:
# Close the environment
env.close()
def make_robot_from_config(config_path, overrides=None):
"""Helper function to create a robot from a config file."""
if overrides is None:
overrides = []
robot_cfg = init_hydra_config(config_path, overrides)
return make_robot(robot_cfg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test end-effector control")
parser.add_argument(
"--mode",
type=str,
default="keyboard",
choices=[
"keyboard",
"gamepad",
"keyboard_gym",
"gamepad_gym",
"leader",
"leader_abs",
],
help="Control mode to use",
)
parser.add_argument(
"--task",
type=str,
default="Robot manipulation task",
help="Description of the task being performed",
)
parser.add_argument(
"--push-to-hub",
default=True,
action=argparse.BooleanOptionalAction,
help="Push the dataset to Hugging Face Hub (use --no-push-to-hub to disable)",
)
args = parser.parse_args()
robot = make_robot_from_config("lerobot/configs/robot/so100.yaml", [])
if not robot.is_connected:
robot.connect()
# Example bounds
bounds = {
"max": np.array([0.32170487, 0.201285, 0.10273342]),
"min": np.array([0.16631757, -0.08237468, 0.03364977]),
}
try:
# Determine controller type based on mode prefix
controller = None
if args.mode.startswith("keyboard"):
controller = KeyboardController(
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
)
elif args.mode.startswith("gamepad"):
controller = GamepadController(
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes
teleoperate_delta_inverse_kinematics(
robot, controller, bounds=bounds, fps=10
)
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes
from lerobot.scripts.server.gym_manipulator import make_robot_env
cfg = init_hydra_config("lerobot/configs/env/so100_real.yaml", [])
cfg.env.wrapper.ee_action_space_params.use_gamepad = False
env = make_robot_env(robot, None, cfg)
teleoperate_gym_env(env, controller)
elif args.mode == "leader":
# Leader-follower modes don't use controllers
teleoperate_delta_inverse_kinematics_with_leader(robot)
elif args.mode == "leader_abs":
teleoperate_inverse_kinematics_with_leader(robot)
finally:
if robot.is_connected:
robot.disconnect()
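A minimal sketch for discovering the vendor and product IDs that `find_device` matches against, using the same `hidapi` Python package (`import hid`) as above:

import hid

# Each entry returned by hid.enumerate() is a dict containing, among others,
# 'vendor_id', 'product_id', 'product_string' and 'path' -- the same fields
# find_device() inspects above.
for dev in hid.enumerate():
    print(
        f"0x{dev['vendor_id']:04X}:0x{dev['product_id']:04X} "
        f"{dev.get('product_string', 'Unknown')}"
    )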

View File

@@ -0,0 +1,121 @@
import argparse
import time
import cv2
import numpy as np
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.server.kinematics import RobotKinematics
def find_joint_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
pos_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait 5 seconds for the robot's initial position to stabilize
if time.perf_counter() - start_episode_t < 5:
continue
pos_list.append(robot.follower_arms["main"].read("Present_Position"))
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max_pos = np.max(np.stack(pos_list), 0)
min_pos = np.min(np.stack(pos_list), 0)
print(f"Max angle position per joint {max_pos}")
print(f"Min angle position per joint {min_pos}")
break
def find_ee_bounds(
robot,
control_time_s=30,
display_cameras=False,
):
if not robot.is_connected:
robot.connect()
start_episode_t = time.perf_counter()
ee_list = []
while True:
observation, action = robot.teleop_step(record_data=True)
# Wait 5 seconds for the robot's initial position to stabilize
if time.perf_counter() - start_episode_t < 5:
continue
joint_positions = robot.follower_arms["main"].read("Present_Position")
print(f"Joint positions: {joint_positions}")
ee_list.append(RobotKinematics.fk_gripper_tip(joint_positions)[:3, 3])
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
max_ee = np.max(np.stack(ee_list), 0)
min_ee = np.min(np.stack(ee_list), 0)
print(f"Max ee position {max_ee}")
print(f"Min ee position {min_ee}")
break
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument(
"--mode",
type=str,
default="joint",
choices=["joint", "ee"],
help="Mode to run the script in. Can be 'joint' or 'ee'.",
)
parser.add_argument(
"--control-time-s",
type=int,
default=30,
help="Time step to use for control.",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
if args.mode == "joint":
find_joint_bounds(robot, args.control_time_s)
elif args.mode == "ee":
find_ee_bounds(robot, args.control_time_s)
if robot.is_connected:
robot.disconnect()
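A typical invocation of this script, assuming it is saved as find_bounds.py (the file name here is hypothetical; substitute the script's actual path in the repo):

python find_bounds.py --robot-path lerobot/configs/robot/so100.yaml --mode ee --control-time-s 30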

File diff suppressed because it is too large.

View File

@@ -0,0 +1,55 @@
// Copyright 2024 The HuggingFace Inc. team.
// All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package hil_serl;
// LearnerService: the Actor calls this to push transitions.
// The Learner implements this service.
service LearnerService {
// Actor -> Learner to store transitions
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
rpc StreamParameters(Empty) returns (stream Parameters);
rpc SendTransitions(stream Transition) returns (Empty);
rpc SendInteractions(stream InteractionMessage) returns (Empty);
rpc Ready(Empty) returns (Empty);
}
enum TransferState {
TRANSFER_UNKNOWN = 0;
TRANSFER_BEGIN = 1;
TRANSFER_MIDDLE = 2;
TRANSFER_END = 3;
}
// Messages
message Transition {
TransferState transfer_state = 1;
bytes data = 2;
}
message Parameters {
TransferState transfer_state = 1;
bytes data = 2;
}
message InteractionMessage {
TransferState transfer_state = 1;
bytes data = 2;
}
message Empty {}
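The TransferState enum lets payloads larger than the gRPC message-size limit travel as ordered chunks. A minimal sender-side sketch, assuming the generated hilserl_pb2 module and a 2 MB chunk size (both choices are illustrative, not taken from the files shown here):

import hilserl_pb2

CHUNK_SIZE = 2 * 1024 * 1024  # assumed per-message budget

def chunk_parameters(blob: bytes):
    """Yield Parameters messages that frame `blob` with BEGIN/MIDDLE/END markers."""
    chunks = [blob[i : i + CHUNK_SIZE] for i in range(0, len(blob), CHUNK_SIZE)]
    for i, data in enumerate(chunks):
        if i == len(chunks) - 1:
            state = hilserl_pb2.TransferState.TRANSFER_END
        elif i == 0:
            state = hilserl_pb2.TransferState.TRANSFER_BEGIN
        else:
            state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
        yield hilserl_pb2.Parameters(transfer_state=state, data=data)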

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: hilserl.proto
# Protobuf Python Version: 5.29.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
29,
0,
'',
'hilserl.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TRANSFERSTATE']._serialized_start=275
_globals['_TRANSFERSTATE']._serialized_end=371
_globals['_TRANSITION']._serialized_start=27
_globals['_TRANSITION']._serialized_end=102
_globals['_PARAMETERS']._serialized_start=104
_globals['_PARAMETERS']._serialized_end=179
_globals['_INTERACTIONMESSAGE']._serialized_start=181
_globals['_INTERACTIONMESSAGE']._serialized_end=264
_globals['_EMPTY']._serialized_start=266
_globals['_EMPTY']._serialized_end=273
_globals['_LEARNERSERVICE']._serialized_start=374
_globals['_LEARNERSERVICE']._serialized_end=696
# @@protoc_insertion_point(module_scope)
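Both generated modules (hilserl_pb2.py above and hilserl_pb2_grpc.py below) can be regenerated from hilserl.proto with grpcio-tools, along the lines of:

python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. hilserl.proto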

View File

@@ -0,0 +1,276 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import hilserl_pb2 as hilserl__pb2
GRPC_GENERATED_VERSION = '1.70.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class LearnerServiceStub(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.SendInteractionMessage = channel.unary_unary(
'/hil_serl.LearnerService/SendInteractionMessage',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.StreamParameters = channel.unary_stream(
'/hil_serl.LearnerService/StreamParameters',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Parameters.FromString,
_registered_method=True)
self.SendTransitions = channel.stream_unary(
'/hil_serl.LearnerService/SendTransitions',
request_serializer=hilserl__pb2.Transition.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.SendInteractions = channel.stream_unary(
'/hil_serl.LearnerService/SendInteractions',
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
self.Ready = channel.unary_unary(
'/hil_serl.LearnerService/Ready',
request_serializer=hilserl__pb2.Empty.SerializeToString,
response_deserializer=hilserl__pb2.Empty.FromString,
_registered_method=True)
class LearnerServiceServicer(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
def SendInteractionMessage(self, request, context):
"""Actor -> Learner to store transitions
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def StreamParameters(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendTransitions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SendInteractions(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Ready(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_LearnerServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
servicer.SendInteractionMessage,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'StreamParameters': grpc.unary_stream_rpc_method_handler(
servicer.StreamParameters,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Parameters.SerializeToString,
),
'SendTransitions': grpc.stream_unary_rpc_method_handler(
servicer.SendTransitions,
request_deserializer=hilserl__pb2.Transition.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'SendInteractions': grpc.stream_unary_rpc_method_handler(
servicer.SendInteractions,
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
'Ready': grpc.unary_unary_rpc_method_handler(
servicer.Ready,
request_deserializer=hilserl__pb2.Empty.FromString,
response_serializer=hilserl__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'hil_serl.LearnerService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class LearnerService(object):
"""LearnerService: the Actor calls this to push transitions.
The Learner implements this service.
"""
@staticmethod
def SendInteractionMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/SendInteractionMessage',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def StreamParameters(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/hil_serl.LearnerService/StreamParameters',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Parameters.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendTransitions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendTransitions',
hilserl__pb2.Transition.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SendInteractions(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(
request_iterator,
target,
'/hil_serl.LearnerService/SendInteractions',
hilserl__pb2.InteractionMessage.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def Ready(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/hil_serl.LearnerService/Ready',
hilserl__pb2.Empty.SerializeToString,
hilserl__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
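A minimal actor-side client built on these stubs might look like the following sketch (the host and port are placeholders; the real values come from cfg.actor_learner_config in the learner server below):

import grpc
import hilserl_pb2
import hilserl_pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:50051")
stub = hilserl_pb2_grpc.LearnerServiceStub(channel)

# Block until the learner answers, then consume the parameter stream.
stub.Ready(hilserl_pb2.Empty())
for params in stub.StreamParameters(hilserl_pb2.Empty()):
    print(f"received {len(params.data)} parameter bytes (state={params.transfer_state})")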

View File

@@ -0,0 +1,543 @@
import numpy as np
from scipy.spatial.transform import Rotation
def skew_symmetric(w):
"""Creates the skew-symmetric matrix from a 3D vector."""
return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]])
def rodrigues_rotation(w, theta):
"""Computes the rotation matrix using Rodrigues' formula."""
w_hat = skew_symmetric(w)
return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
def screw_axis_to_transform(S, theta):
"""Converts a screw axis to a 4x4 transformation matrix."""
S_w = S[:3]
S_v = S[3:]
if np.allclose(S_w, 0) and np.isclose(np.linalg.norm(S_v), 1):  # Pure translation
T = np.eye(4)
T[:3, 3] = S_v * theta
elif np.isclose(np.linalg.norm(S_w), 1):  # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (
np.eye(3) * theta
+ (1 - np.cos(theta)) * w_hat
+ (theta - np.sin(theta)) * w_hat @ w_hat
) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
else:
raise ValueError("Invalid screw axis parameters")
return T
def pose_difference_se3(pose1, pose2):
"""
Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices.
pose1 - pose2
Args:
pose1: A 4x4 numpy array representing the first pose.
pose2: A 4x4 numpy array representing the second pose.
Returns:
A 6-element numpy array whose first three entries are the translational
difference and whose last three are the rotational difference in
axis-angle (rotation-vector) representation.
"""
# Extract rotation matrices from poses
R1 = pose1[:3, :3]
R2 = pose2[:3, :3]
# Calculate translational difference
translation_diff = pose1[:3, 3] - pose2[:3, 3]
# Calculate rotational difference using scipy's Rotation library
R_diff = Rotation.from_matrix(R1 @ R2.T)
rotation_diff = R_diff.as_rotvec() # Convert to axis-angle representation
return np.concatenate([translation_diff, rotation_diff])
def se3_error(target_pose, current_pose):
pos_error = target_pose[:3, 3] - current_pose[:3, 3]
R_target = target_pose[:3, :3]
R_current = current_pose[:3, :3]
R_error = R_target @ R_current.T
rot_error = Rotation.from_matrix(R_error).as_rotvec()
return np.concatenate([pos_error, rot_error])
class RobotKinematics:
"""Robot kinematics class supporting multiple robot models."""
# Robot measurements dictionary
ROBOT_MEASUREMENTS = {
"koch": {
"gripper": [0.239, -0.001, 0.024],
"wrist": [0.209, 0, 0.024],
"forearm": [0.108, 0, 0.02],
"humerus": [0, 0, 0.036],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"so100": {
"gripper": [0.320, 0, 0.050],
"wrist": [0.278, 0, 0.050],
"forearm": [0.143, 0, 0.044],
"humerus": [0.031, 0, 0.072],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
"moss": {
"gripper": [0.246, 0.013, 0.111],
"wrist": [0.245, 0.002, 0.064],
"forearm": [0.122, 0, 0.064],
"humerus": [0.001, 0.001, 0.063],
"shoulder": [0, 0, 0],
"base": [0, 0, 0.02],
},
}
def __init__(self, robot_type="so100"):
"""Initialize kinematics for the specified robot type.
Args:
robot_type: String specifying the robot model ("koch", "so100", or "moss")
"""
if robot_type not in self.ROBOT_MEASUREMENTS:
raise ValueError(
f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}"
)
self.robot_type = robot_type
self.measurements = self.ROBOT_MEASUREMENTS[robot_type]
# Initialize all transformation matrices and screw axes
self._setup_transforms()
def _create_translation_matrix(self, x=0, y=0, z=0):
"""Create a 4x4 translation matrix."""
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]])
def _setup_transforms(self):
"""Setup all transformation matrices and screw axes for the robot."""
# Set up rotation matrices (constant across robot types)
# Gripper orientation
self.gripper_X0 = np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
# Wrist orientation
self.wrist_X0 = np.array(
[
[0, -1, 0, 0],
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
# Base orientation
self.base_X0 = np.array(
[
[0, 0, 1, 0],
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
)
# Gripper
# Screw axis of gripper frame wrt base frame
self.S_BG = np.array(
[
1,
0,
0,
0,
self.measurements["gripper"][2],
-self.measurements["gripper"][1],
]
)
# Gripper origin to centroid transform
self.X_GoGc = self._create_translation_matrix(x=0.07)
# Gripper origin to tip transform
self.X_GoGt = self._create_translation_matrix(x=0.12)
# 0-position gripper frame pose wrt base
self.X_BoGo = self._create_translation_matrix(
x=self.measurements["gripper"][0],
y=self.measurements["gripper"][1],
z=self.measurements["gripper"][2],
)
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array(
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
)
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
# 0-position wrist frame pose wrt base
self.X_BR = self._create_translation_matrix(
x=self.measurements["wrist"][0],
y=self.measurements["wrist"][1],
z=self.measurements["wrist"][2],
)
# Forearm
# Screw axis of forearm frame wrt base frame
self.S_BF = np.array(
[
0,
1,
0,
-self.measurements["forearm"][2],
0,
self.measurements["forearm"][0],
]
)
# Forearm origin to centroid transform
self.X_FoFc = self._create_translation_matrix(x=0.036)
# 0-position forearm frame pose wrt base
self.X_BF = self._create_translation_matrix(
x=self.measurements["forearm"][0],
y=self.measurements["forearm"][1],
z=self.measurements["forearm"][2],
)
# Humerus
# Screw axis of humerus frame wrt base frame
self.S_BH = np.array(
[
0,
-1,
0,
self.measurements["humerus"][2],
0,
-self.measurements["humerus"][0],
]
)
# Humerus origin to centroid transform
self.X_HoHc = self._create_translation_matrix(x=0.0475)
# 0-position humerus frame pose wrt base
self.X_BH = self._create_translation_matrix(
x=self.measurements["humerus"][0],
y=self.measurements["humerus"][1],
z=self.measurements["humerus"][2],
)
# Shoulder
# Screw axis of shoulder frame wrt Base frame
self.S_BS = np.array([0, 0, -1, 0, 0, 0])
# Shoulder origin to centroid transform
self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235)
# 0-position shoulder frame pose wrt base
self.X_BS = self._create_translation_matrix(
x=self.measurements["shoulder"][0],
y=self.measurements["shoulder"][1],
z=self.measurements["shoulder"][2],
)
# Base
# Base origin to centroid transform
self.X_BoBc = self._create_translation_matrix(y=0.015)
# World to base transform
self.X_WoBo = self._create_translation_matrix(
x=self.measurements["base"][0],
y=self.measurements["base"][1],
z=self.measurements["base"][2],
)
# Pre-compute gripper post-multiplication matrix
self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0
def fk_base(self):
"""Forward kinematics for the base frame."""
return self.X_WoBo @ self.X_BoBc @ self.base_X0
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ self.X_SoSc
@ self.X_BS
)
def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ self.X_HoHc
@ self.X_BH
)
def fk_forearm(self, robot_pos_deg):
"""Forward kinematics for the forearm frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ self.X_FoFc
@ self.X_BF
)
def fk_wrist(self, robot_pos_deg):
"""Forward kinematics for the wrist frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ self.X_RoRc
@ self.X_BR
@ self.wrist_X0
)
def fk_gripper(self, robot_pos_deg):
"""Forward kinematics for the gripper frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self._fk_gripper_post
)
def fk_gripper_tip(self, robot_pos_deg):
"""Forward kinematics for the gripper tip frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ screw_axis_to_transform(self.S_BH, robot_pos_rad[1])
@ screw_axis_to_transform(self.S_BF, robot_pos_rad[2])
@ screw_axis_to_transform(self.S_BR, robot_pos_rad[3])
@ screw_axis_to_transform(self.S_BG, robot_pos_rad[4])
@ self.X_GoGt
@ self.X_BoGo
@ self.gripper_X0
)
def compute_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the Jacobian.
J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change
in the jth joint's velocity.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(6, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
pose_difference_se3(
fk_func(robot_pos_deg[:-1] + delta),
fk_func(robot_pos_deg[:-1] - delta),
)
/ eps
)
jac[:, el_ix] = Sdot
return jac
def compute_positional_jacobian(self, robot_pos_deg, fk_func=None):
"""Finite differences to compute the positional Jacobian.
J(i, j) represents how the ith component of the end-effector's position changes wrt a small change
in the jth joint angle.
Args:
robot_pos_deg: Current joint positions in degrees
fk_func: Forward kinematics function to use (defaults to fk_gripper)
"""
if fk_func is None:
fk_func = self.fk_gripper
eps = 1e-8
jac = np.zeros(shape=(3, 5))
delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64)
for el_ix in range(len(robot_pos_deg[:-1])):
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac
def ik(
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
):
"""Inverse kinematics using gradient descent.
Args:
current_joint_state: Initial joint positions in degrees
desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix
position_only: If True, only match end-effector position, not orientation
fk_func: Forward kinematics function to use (defaults to fk_gripper)
Returns:
Joint positions in degrees that achieve the desired end-effector pose
"""
if fk_func is None:
fk_func = self.fk_gripper
# Do gradient descent.
max_iterations = 5
learning_rate = 1
for _ in range(max_iterations):
current_ee_pose = fk_func(current_joint_state)
if not position_only:
error = se3_error(desired_ee_pose, current_ee_pose)
jac = self.compute_jacobian(current_joint_state, fk_func)
else:
error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3]
jac = self.compute_positional_jacobian(current_joint_state, fk_func)
delta_angles = np.linalg.pinv(jac) @ error
current_joint_state[:-1] += learning_rate * delta_angles
if np.linalg.norm(error) < 5e-3:
return current_joint_state
return current_joint_state
if __name__ == "__main__":
import time
def run_test(robot_type):
"""Run test suite for a specific robot type."""
print(f"\n--- Testing {robot_type.upper()} Robot ---")
# Initialize kinematics for this robot
robot = RobotKinematics(robot_type)
# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array(
[30, 45, -30, 20, 10, 0]
) # Example joint angles in degrees
# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
humerus_pose = robot.fk_humerus(test_angles)
forearm_pose = robot.fk_forearm(test_angles)
wrist_pose = robot.fk_wrist(test_angles)
gripper_pose = robot.fk_gripper(test_angles)
gripper_tip_pose = robot.fk_gripper_tip(test_angles)
# Check that poses form a consistent kinematic chain (positions should be progressively further from origin)
distances = [
np.linalg.norm(shoulder_pose[:3, 3]),
np.linalg.norm(humerus_pose[:3, 3]),
np.linalg.norm(forearm_pose[:3, 3]),
np.linalg.norm(wrist_pose[:3, 3]),
np.linalg.norm(gripper_pose[:3, 3]),
np.linalg.norm(gripper_tip_pose[:3, 3]),
]
# Check if distances generally increase along the chain
is_consistent = all(
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
)
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
)
# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
jacobian = robot.compute_jacobian(test_angles)
positional_jacobian = robot.compute_positional_jacobian(test_angles)
# Check shapes
jacobian_shape_ok = jacobian.shape == (6, 5)
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
)
# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")
# Generate target pose from known joint angles
original_angles = np.array([10, 20, 30, -10, 5, 0])
target_pose = robot.fk_gripper(original_angles)
# Start IK from a different position
initial_guess = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Measure IK performance
start_time = time.time()
computed_angles = robot.ik(initial_guess.copy(), target_pose)
ik_time = time.time() - start_time
# Compute resulting pose from IK solution
result_pose = robot.fk_gripper(computed_angles)
# Calculate position error
pos_error = np.linalg.norm(target_pose[:3, 3] - result_pose[:3, 3])
passed = pos_error < 0.01 # Accept errors less than 1cm
print(f" IK computation time: {ik_time:.4f} seconds")
print(f" Position error: {pos_error:.4f}")
print(f" IK position accuracy: {'PASSED' if passed else 'FAILED'}")
return is_consistent and jacobian_shape_ok and pos_jacobian_shape_ok and passed
# Run tests for all robot types
results = {}
for robot_type in ["koch", "so100", "moss"]:
results[robot_type] = run_test(robot_type)
# Print overall summary
print("\n=== Test Summary ===")
all_passed = all(results.values())
for robot_type, passed in results.items():
print(f"{robot_type.upper()}: {'PASSED' if passed else 'FAILED'}")
print(f"\nOverall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")

View File

@@ -0,0 +1,870 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
import time
from pprint import pformat
from concurrent.futures import ThreadPoolExecutor
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from lerobot.scripts.server.utils import setup_process_handlers
import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.utils.utils import (
format_big_number,
get_global_random_state,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object,
)
from lerobot.scripts.server import learner_service
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
if not cfg.resume:
if Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
"Use `resume=true` to resume training."
)
return cfg
# resume=true: restore and return the checkpointed config
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"Resume=True detected, resuming previous run",
color="yellow",
attrs=["bold"],
)
)
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
"Checkpoint configuration takes precedence."
)
checkpoint_cfg.resume = True
return checkpoint_cfg
def load_training_state(
cfg: DictConfig,
logger: Logger,
optimizers: Optimizer | dict,
):
if not cfg.resume:
return None, None
training_state = torch.load(
logger.last_checkpoint_dir / logger.training_state_file_name, weights_only=False
)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
for k, v in training_state["optimizer"].items():
optimizers[k].load_state_dict(v)
else:
optimizers.load_state_dict(training_state["optimizer"])
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], training_state["interaction_step"]
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.online_steps=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
storage_device=storage_device,
optimize_memory=True,
)
logging.info("Resume training load the online dataset")
dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset",
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
optimize_memory=True,
)
def initialize_offline_replay_buffer(
cfg: DictConfig,
logger: Logger,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
if not cfg.resume:
logging.info("Building the offline dataset with make_dataset")
offline_dataset = make_dataset(cfg)
else:
logging.info("Loading the offline dataset saved with the checkpoint")
offline_dataset = LeRobotDataset(
repo_id=cfg.dataset_repo_id,
local_files_only=True,
root=logger.log_dir / "dataset_offline",
)
logging.info("Converting the offline dataset to an offline replay buffer")
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.training.offline_buffer_capacity,
)
return offline_replay_buffer
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if (
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
return None, None
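# The encoder is frozen past this point, so its features for a given batch never
# change; compute them once under no_grad and reuse them across the critic, actor
# and temperature losses to avoid repeated forward passes through the vision encoder.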
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
return observation_features, next_observation_features
def use_threads(cfg: DictConfig) -> bool:
return cfg.actor_learner_config.concurrency.learner == "threads"
def start_learner_threads(
cfg: DictConfig,
logger: Logger,
out_dir: str,
shutdown_event: any, # Event,
) -> None:
# Create multiprocessing queues
transition_queue = Queue()
interaction_message_queue = Queue()
parameters_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from torch.multiprocessing import Process
concurrency_entity = Process
communication_process = concurrency_entity(
target=start_learner_server,
args=(
parameters_queue,
transition_queue,
interaction_message_queue,
shutdown_event,
cfg,
),
daemon=True,
)
communication_process.start()
add_actor_information_and_train(
cfg,
logger,
out_dir,
shutdown_event,
transition_queue,
interaction_message_queue,
parameters_queue,
)
logging.info("[LEARNER] Training process stopped")
logging.info("[LEARNER] Closing queues")
transition_queue.close()
interaction_message_queue.close()
parameters_queue.close()
communication_process.join()
logging.info("[LEARNER] Communication process joined")
logging.info("[LEARNER] join queues")
transition_queue.cancel_join_thread()
interaction_message_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[LEARNER] queues closed")
def start_learner_server(
parameters_queue: Queue,
transition_queue: Queue,
interaction_message_queue: Queue,
shutdown_event: any, # Event,
cfg: DictConfig,
):
if not use_threads(cfg):
# We need to initialize logging separately for the multiprocessing case
init_logging()
# Set up signal handlers for this process, but reuse the shutdown event
# created by the main process
setup_process_handlers(False)
service = learner_service.LearnerService(
shutdown_event,
parameters_queue,
cfg.actor_learner_config.policy_parameters_push_frequency,
transition_queue,
interaction_message_queue,
)
server = grpc.server(
ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
options=[
("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
],
)
hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
service,
server,
)
host = cfg.actor_learner_config.learner_host
port = cfg.actor_learner_config.learner_port
server.add_insecure_port(f"{host}:{port}")
server.start()
logging.info("[LEARNER] gRPC server started")
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)
logging.info("[LEARNER] gRPC server stopped")
def check_nan_in_transition(
observations: torch.Tensor,
actions: torch.Tensor,
next_state: torch.Tensor,
raise_error: bool = False,
) -> bool:
"""
Check for NaN values in transition data.
Args:
observations: Dictionary of observation tensors
actions: Action tensor
next_state: Dictionary of next state tensors
raise_error: If True, raises ValueError when NaN is detected
Returns:
bool: True if NaN values were detected, False otherwise
"""
nan_detected = False
# Check observations
for key, tensor in observations.items():
if torch.isnan(tensor).any():
logging.error(f"observations[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in observations[{key}]")
# Check next state
for key, tensor in next_state.items():
if torch.isnan(tensor).any():
logging.error(f"next_state[{key}] contains NaN values")
nan_detected = True
if raise_error:
raise ValueError(f"NaN detected in next_state[{key}]")
# Check actions
if torch.isnan(actions).any():
logging.error("actions contains NaN values")
nan_detected = True
if raise_error:
raise ValueError("NaN detected in actions")
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu")
state_bytes = state_to_bytes(state_dict)
parameters_queue.put(state_bytes)
def add_actor_information_and_train(
cfg,
logger: Logger,
out_dir: str,
shutdown_event: any, # Event,
transition_queue: Queue,
interaction_message_queue: Queue,
parameters_queue: Queue,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
and logs training progress in an online reinforcement learning setup.
This function continuously:
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args:
cfg: Configuration object containing hyperparameters.
logger (Logger): Logger instance for tracking training progress.
out_dir (str): The output directory for storing training checkpoints and logs.
shutdown_event (Event): Event to signal shutdown.
transition_queue (Queue): Queue for receiving transitions from the actor.
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
device = get_safe_torch_device(cfg.device, log=True)
storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
logging.info("Initializing policy")
### Instantiate the policy in both the actor and learner processes.
### To avoid sending a SACPolicy object over the network, we create a policy instance
### on both sides; the learner sends updated parameters every n steps to refresh the actor.
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
hydra_cfg=cfg,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
log_training_info(cfg, out_dir, policy)
replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
batch_size = cfg.training.batch_size
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
logger=logger,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
# NOTE: This function doesn't have a single responsibility; it should be split into multiple
# functions in the future. The reason is Python's GIL: spreading this work across several
# threads slows it down dramatically (roughly 200x in our tests), so everything runs in one thread.
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = (
resume_optimization_step if resume_optimization_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
# Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
dataset_repo_id = cfg.dataset_repo_id
fps = cfg.fps
log_freq = cfg.training.log_freq
save_freq = cfg.training.save_freq
device = cfg.device
storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = (
cfg.actor_learner_config.policy_parameters_push_frequency
)
save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps
while True:
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
logging.debug("[LEARNER] Waiting for transitions")
while not transition_queue.empty() and not shutdown_event.is_set():
transition_list = transition_queue.get()
transition_list = bytes_to_transitions(transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
if check_nan_in_transition(
transition["state"], transition["action"], transition["next_state"]
):
logging.warning("NaN detected in transition, skipping")
continue
replay_buffer.add(**transition)
if cfg.dataset_repo_id is not None and transition.get(
"complementary_info", {}
).get("is_intervention"):
offline_replay_buffer.add(**transition)
logging.debug("[LEARNER] Received transitions")
logging.debug("[LEARNER] Waiting for interactions")
while not interaction_message_queue.empty() and not shutdown_event.is_set():
interaction_message = interaction_message_queue.get()
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
logging.debug("[LEARNER] Received interactions")
if len(replay_buffer) < online_step_before_learning:
continue
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
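# UTD (update-to-data) ratio: perform `utd_ratio` critic updates per collected batch.
# The first `utd_ratio - 1` passes below touch only the critic; the final pass after
# this loop also refreshes the actor and temperature on the policy_update_freq schedule.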
for _ in range(utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
).item()
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None:
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
)
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
loss_actor = policy.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
).item()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
).item()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue, policy)
last_time_policy_pushed = time.time()
policy.update_target_networks()
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(
offline_replay_buffer
)
training_infos["Optimization step"] = optimization_step
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
)
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (
time_for_one_optimization_step + 1e-9
)
logging.info(
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logger.log_dict(
{
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
"Optimization step": optimization_step,
},
mode="train",
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if save_checkpoint and (
optimization_step % save_freq == 0 or optimization_step == online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"]
if interaction_message is not None
else 0
)
logger.save_checkpoint(
optimization_step,
policy,
optimizers,
scheduler=None,
identifier=step_identifier,
interaction_step=interaction_step,
)
# TODO: temporarily save the replay buffer here; remove once running on the real robot,
# where we want to trigger this from keyboard input instead
dataset_dir = logger.log_dir / "dataset"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
)
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"
if dataset_dir.exists() and dataset_dir.is_dir():
shutil.rmtree(
dataset_dir,
)
offline_replay_buffer.to_lerobot_dataset(
cfg.dataset_repo_id,
fps=cfg.fps,
root=logger.log_dir / "dataset_offline",
)
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though it is currently set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler
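# Hedged usage sketch (assumption, not part of the original file): exercising
# make_optimizers_and_scheduler() with a minimal stand-in policy. Only the attribute
# names the function reads (actor.parameters_to_optimize, critic_ensemble, log_alpha,
# config.actor_lr, config.critic_lr) are real expectations; everything else is a toy.
def _sketch_optimizer_setup() -> None:
    import types

    import torch
    from torch import nn

    actor = nn.Linear(4, 2)
    actor.parameters_to_optimize = list(actor.parameters())
    policy = types.SimpleNamespace(
        actor=actor,
        critic_ensemble=nn.Linear(4, 1),
        log_alpha=nn.Parameter(torch.zeros(1)),
        config=types.SimpleNamespace(actor_lr=3e-4, critic_lr=3e-4),
    )
    optimizers, _ = make_optimizers_and_scheduler(cfg=None, policy=policy)  # cfg is unused
    # One illustrative critic update with a dummy regression loss.
    loss = policy.critic_ensemble(torch.randn(8, 4)).pow(2).mean()
    optimizers["critic"].zero_grad()
    loss.backward()
    optimizers["critic"].step()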
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
cfg = handle_resume_logic(cfg, out_dir)
set_global_seed(cfg.seed)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
shutdown_event = setup_process_handlers(use_threads(cfg))
start_learner_threads(
cfg,
logger,
out_dir,
shutdown_event,
)
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
logging.info("[LEARNER] train_cli finished")
if __name__ == "__main__":
train_cli()
logging.info("[LEARNER] main finished")


@@ -0,0 +1,82 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import logging
from multiprocessing import Event, Queue
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3  # One worker per stream: parameters, transitions, and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def __init__(
self,
shutdown_event: Event,
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: Queue,
interaction_message_queue: Queue,
):
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
def StreamParameters(self, request, context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
while not self.shutdown_event.is_set():
logging.info("[LEARNER] Push parameters to the Actor")
buffer = self.parameters_queue.get()
yield from send_bytes_in_chunks(
buffer,
hilserl_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
logging.info("[LEARNER] Parameters sent")
self.shutdown_event.wait(self.seconds_between_pushes)
logging.info("[LEARNER] Stream parameters finished")
return hilserl_pb2.Empty()
def SendTransitions(self, request_iterator, _context):
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return hilserl_pb2.Empty()
def SendInteractions(self, request_iterator, _context):
# TODO: authorize the request
logging.info(
"[LEARNER] Received request to receive interactions from the Actor"
)
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return hilserl_pb2.Empty()
def Ready(self, request, context):
return hilserl_pb2.Empty()
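# Hedged wiring sketch (assumption, not part of the original file): attaching this
# servicer to a grpc.server. The helper name `add_LearnerServiceServicer_to_server`
# follows grpc's standard codegen convention for a `LearnerService`; the message-size
# options reuse the module constants above.
def _sketch_serve(service: "LearnerService", port: int = 50051) -> None:
    from concurrent import futures

    import grpc

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
        ],
    )
    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(service, server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    service.shutdown_event.wait()  # block until the learner asks to shut down
    server.stop(grace=SHUTDOWN_TIMEOUT)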


@@ -0,0 +1,192 @@
import einops
import numpy as np
import gymnasium as gym
import torch
from omegaconf import DictConfig
from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from mani_skill.utils.wrappers.record import RecordEpisode
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observations: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
# map to expected inputs for the policy
return_observations = {}
# TODO: Merge all tensors under the "agent" and "extra" keys into the state vector;
# drop the "sensor_param" key and keep only the RGB image from "sensor_data".
q_pos = observations["agent"]["qpos"]
q_vel = observations["agent"]["qvel"]
tcp_pos = observations["extra"]["tcp_pose"]
img = observations["sensor_data"]["base_camera"]["rgb"]
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
return_observations["observation.image"] = img
return_observations["observation.state"] = state
return return_observations
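# Minimal shape-check sketch (assumption, not part of the original file): feeding a
# synthetic batched ManiSkill-style observation through the converter above. The state
# widths (qpos, qvel, tcp_pose) and the image size are arbitrary toy values.
def _sketch_preprocess_shapes() -> None:
    import torch

    batch, h, w = 2, 64, 64
    fake_obs = {
        "agent": {"qpos": torch.zeros(batch, 9), "qvel": torch.zeros(batch, 9)},
        "extra": {"tcp_pose": torch.zeros(batch, 7)},
        "sensor_data": {"base_camera": {"rgb": torch.zeros(batch, h, w, 3, dtype=torch.uint8)}},
    }
    out = preprocess_maniskill_observation(fake_obs)
    assert out["observation.image"].shape == (batch, 3, h, w)  # channel-first, float32 in [0, 1]
    assert out["observation.state"].shape == (batch, 9 + 9 + 7)  # qpos | qvel | tcp_pose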
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
self.device = device
def observation(self, observation):
observation = preprocess_maniskill_observation(observation)
observation = {k: v.to(self.device) for k, v in observation.items()}
return observation
class ManiSkillCompat(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
options = {}
return super().reset(seed=seed, options=options)
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
reward = reward.item()
terminated = terminated.item()
truncated = truncated.item()
return obs, reward, terminated, truncated, info
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action):
action, telop = action
return action
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
def __init__(self, env, multiply_factor: float = 1):
super().__init__(env)
self.multiply_factor = multiply_factor
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action):
if isinstance(action, tuple):
action, telop = action
else:
telop = 0
action = action / self.multiply_factor
obs, reward, terminated, truncated, info = self.env.step((action, telop))
return obs, reward, terminated, truncated, info
def make_maniskill(
cfg: DictConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
Args:
cfg: Hydra config whose `env` section provides the task name, observation mode,
control mode, render mode, image size, device, and video-recording options.
n_envs: Number of parallel environments
Returns:
A wrapped ManiSkill environment
"""
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
num_envs=n_envs,
)
if cfg.env.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.env.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.env.video_record.trajectory_name,
save_video=True,
video_fps=30,
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
return env
if __name__ == "__main__":
import argparse
import hydra
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args()
# Initialize config
with hydra.initialize(version_base=None, config_path="../../configs"):
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
env = make_maniskill(cfg, n_envs=1)
print("env done")
obs, info = env.reset()
random_action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(random_action)
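# Hedged usage note (assumption): after ManiSkillActionWrapper, the action space is a
# Tuple of (continuous arm action, Discrete(2) teleoperation flag), so an explicit step
# looks like this; the flag is unpacked by the wrappers and dropped before the base env.
arm_action = env.action_space[0].sample()  # continuous component only
obs, reward, terminated, truncated, info = env.step((arm_action, 0))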

Some files were not shown because too many files have changed in this diff.