Compare commits

...

319 Commits

Author SHA1 Message Date
Remi Cadene
7cee7a0f20 Add mobile Aloha and visu with rerun.io 2024-04-20 16:19:55 +02:00
Cadene
2a59825a00 fix online training 2024-04-20 00:12:34 +00:00
Cadene
06628ba059 fix online training 2024-04-19 23:58:38 +00:00
Cadene
b2b5329683 fix online training 2024-04-19 23:48:43 +00:00
Cadene
85f1554da8 fix visualize_dataset 2024-04-19 23:40:35 +00:00
Cadene
9b4c2e2a9f small fix 2024-04-19 23:30:39 +00:00
Cadene
20928021c0 Add tests/data 2024-04-19 23:27:11 +00:00
Cadene
c20cf2fbbc Remove Prod, Tests are passind 2024-04-19 23:27:10 +00:00
Cadene
35a573c98e Use v1.1, hf_transform_to_torch, Add 3 xarm datasets 2024-04-19 23:26:13 +00:00
Cadene
714a776277 id -> index, finish moving compute_stats before hf_dataset push_to_hub 2024-04-19 23:25:06 +00:00
Cadene
64b09ea7a7 WIP add load functions + episode_data_index 2024-04-19 23:24:08 +00:00
Cadene
0bd2ca8d82 Add meta_data, revision v1.1 2024-04-19 23:24:08 +00:00
Alexander Soare
e2168163cd Quality of life patches for eval.py (#86) 2024-04-19 12:33:47 +01:00
Simon Alibert
ac0ab27333 Hotfix test_examples.py (#87) 2024-04-19 12:36:04 +02:00
Alexander Soare
8d980940a2 Fix tolerance for delta_timestamps (#84)
Co-authored-by: Remi <re.cadene@gmail.com>
2024-04-18 18:48:22 +01:00
Simon Alibert
7ad1909641 Tests cleaning & simplification (#81) 2024-04-18 14:47:42 +02:00
Remi
0928afd37d Improve dataset examples (#82)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-04-18 11:43:16 +02:00
Alexander Soare
d5c4b0c344 Merge pull request #80 from huggingface/alexander-soare/unify_policy_api
Unify policy API
2024-04-17 17:05:22 +01:00
Alexander Soare
dd9c6eed15 Add temporary patch in TD-MPC 2024-04-17 16:27:57 +01:00
Alexander Soare
2298ddf226 wip 2024-04-17 16:21:37 +01:00
Alexander Soare
63e5ec6483 revert some formatting changes 2024-04-17 11:40:49 +01:00
Alexander Soare
c50a13ab31 draft 2024-04-17 10:50:54 +01:00
Alexander Soare
18dd8f32cd Merge remote-tracking branch 'upstream/main' into unify_policy_api 2024-04-17 09:10:30 +01:00
Alexander Soare
6d78bca749 Merge pull request #77 from huggingface/alexander-soare/fix_save_stats
Fix issue with saving freshly computed stats
2024-04-17 09:09:18 +01:00
Alexander Soare
296bbfe1ad Merge branch 'fix_stats_saving' into unify_policy_api 2024-04-17 09:08:04 +01:00
Alexander Soare
c9454333d8 revision 2024-04-17 09:02:35 +01:00
Alexander Soare
1331d3b4e4 fix issue with saving freshly computed stats 2024-04-17 08:49:28 +01:00
Alexander Soare
bff4b673c9 Merge remote-tracking branch 'upstream/main' into unify_policy_api 2024-04-17 08:08:57 +01:00
Remi
3f1c322d56 Merge pull request #73 from huggingface/user/rcadene/2024_04_14_hf_datasets
Use Hugging Face datasets.Dataset
2024-04-16 21:54:37 +02:00
Simon Alibert
fbc31d906c Merge pull request #74 from huggingface/user/aliberts/2024_04_15_setup_contributions
Setup contributions
2024-04-16 19:51:13 +02:00
Simon Alibert
4057cc6c28 Apply suggestions from code review
Various fixes for #74

Co-authored-by: Remi <re.cadene@gmail.com>
2024-04-16 19:35:01 +02:00
Cadene
91badebdfc fix tests 2024-04-16 17:29:31 +00:00
Cadene
4327e43f19 fix merge thingy 2024-04-16 17:24:25 +00:00
Cadene
36d9e885ef Address comments 2024-04-16 17:20:54 +00:00
Cadene
b241ea46dd move download_and_upload_dataset.py to root_dir 2024-04-16 17:20:53 +00:00
Cadene
e09d25267e fix online training 2024-04-16 17:20:53 +00:00
Cadene
4a3eac4743 fix unit tests, stats was missing, visualize_dataset was broken 2024-04-16 17:20:53 +00:00
Cadene
69eeced9d9 add datasets to poetry/cpu 2024-04-16 17:20:52 +00:00
Cadene
0980fff6cc HF datasets works 2024-04-16 17:19:40 +00:00
Cadene
5edd9a89a0 Move stats_dataset init into else statement -> faster init 2024-04-16 17:19:39 +00:00
Cadene
c7a8218620 typo 2024-04-16 17:19:39 +00:00
Cadene
67d79732f9 Add download_and_upload_dataset.py in script, update all datasets, update online training 2024-04-16 17:19:39 +00:00
Cadene
c6aca7fe44 For Pusht: use hf datasets to train, rename load_data_with_delta_timestamps -> load_previous_and_future_frames 2024-04-16 17:19:06 +00:00
Alexander Soare
cb3978b5f3 backup wip 2024-04-16 18:12:39 +01:00
Remi
4ed55c3ba3 Merge pull request #76 from huggingface/user/aliberts/2024_04_16_refactor_pyproject
Refactor pyproject
2024-04-16 18:55:09 +02:00
Alexander Soare
0eb899de73 Merge remote-tracking branch 'upstream/main' into unify_policy_api 2024-04-16 17:30:41 +01:00
Alexander Soare
a8ddefcfd4 Merge pull request #71 from huggingface/user/alexander-soare/refactor_dp
Partial refactor of Diffusion Policy
2024-04-16 17:24:05 +01:00
Alexander Soare
a9496fde39 revision 1 2024-04-16 17:15:51 +01:00
Alexander Soare
8a322da422 backup wip 2024-04-16 16:35:04 +01:00
Alexander Soare
23be5e1e7b backup wip 2024-04-16 16:31:44 +01:00
Alexander Soare
43a614c173 Fix test_examples 2024-04-16 14:07:16 +01:00
Alexander Soare
9c2f10bd04 ready for review 2024-04-16 13:43:58 +01:00
Alexander Soare
03b08eb74e backup wip 2024-04-16 12:51:32 +01:00
Simon Alibert
38ef878eed Add COC badge 2024-04-16 13:37:20 +02:00
Simon Alibert
dee174f678 pre-commit autoupdate 2024-04-16 12:10:26 +02:00
Simon Alibert
70e8de95f7 Clean & update pyproject 2024-04-16 12:09:56 +02:00
Simon Alibert
376d75f8d3 Add env info 2024-04-16 10:35:43 +02:00
Simon Alibert
a621ec8d88 Add PR & issue templates 2024-04-16 10:34:29 +02:00
Simon Alibert
f9b48abe64 Add CONTRIBUTING 2024-04-16 10:33:58 +02:00
Simon Alibert
352010b718 Add COC 2024-04-16 10:32:09 +02:00
Alexander Soare
5608e659e6 backup wip 2024-04-15 19:06:44 +01:00
Alexander Soare
14f3ffb412 Merge remote-tracking branch 'upstream/main' into refactor_dp 2024-04-15 17:08:28 +01:00
Alexander Soare
e37f4e8c53 Merge pull request #72 from huggingface/alexander-soare/policy_config
Use policy configs instead of passing arguments directly to policy classes
2024-04-15 16:44:21 +01:00
Alexander Soare
30023535f9 revision 1 2024-04-15 10:56:43 +01:00
Alexander Soare
40d417ef60 Make sure to make remove all traces of omegaconf from policy config 2024-04-15 09:59:18 +01:00
Alexander Soare
9241b5e830 pass step as kwarg 2024-04-15 09:52:54 +01:00
Alexander Soare
2ccf89d78c try fix tests 2024-04-15 09:48:03 +01:00
Alexander Soare
ef4bd9e25c Use dataclass config for ACT 2024-04-15 09:39:23 +01:00
Alexander Soare
34f00753eb remove policy.py 2024-04-12 17:13:25 +01:00
Alexander Soare
55e484124a draft pr 2024-04-12 17:03:59 +01:00
Alexander Soare
6d0a45a97d ready for review 2024-04-12 11:36:52 +01:00
Alexander Soare
5666ec3ec7 backup wip 2024-04-11 18:33:54 +01:00
Alexander Soare
94cc22da9e Merge remote-tracking branch 'upstream/main' into refactor_dp 2024-04-11 17:52:10 +01:00
Alexander Soare
976a197f98 backup wip 2024-04-11 17:51:35 +01:00
Remi
5bd953e8e7 Merge pull request #64 from huggingface/user/rcadene/2024_03_31_remove_torchrl
Remove torchrl
2024-04-11 16:45:16 +02:00
Cadene
c1a618e567 fix pusht images type from float32 to uint8, update gym-pusht dependencies 2024-04-11 14:29:16 +00:00
Cadene
4216636084 fix 2024-04-11 14:01:23 +00:00
Cadene
8e5b4365ac fix 2024-04-11 13:57:22 +00:00
Cadene
84d2468da1 fix 2024-04-11 13:51:13 +00:00
Cadene
5be83fbff6 fix 2024-04-11 13:46:13 +00:00
Cadene
a605eec7e9 Use pusht as example so it's fast 2024-04-11 13:36:50 +00:00
Cadene
2f4af32d3f small fix 2024-04-11 13:21:06 +00:00
Cadene
36de77ac18 fix and clarify tests 2024-04-11 13:16:47 +00:00
Cadene
92701088a3 small fix 2024-04-11 13:04:27 +00:00
Cadene
657b27cc8f fix load_data_with_delta_timestamps and add tests 2024-04-11 13:00:09 +00:00
Remi
9229226522 Update lerobot/common/envs/utils.py 2024-04-11 10:35:17 +02:00
Simon Alibert
35069bb3e2 Update CI dep 2024-04-11 09:03:58 +02:00
Simon Alibert
0593d20348 Fix gym-xarm 2024-04-11 08:51:37 +02:00
Cadene
949f4d1a5b remove comment 2024-04-10 17:21:36 +00:00
Cadene
3914831585 remove __name__ outside script 2024-04-10 17:16:44 +00:00
Cadene
f8c5a2eb10 remove comment 2024-04-10 17:14:02 +00:00
Cadene
9874652c2f enable test_compute_stats
enable test_compute_stats
2024-04-10 17:12:54 +00:00
Remi
4c3d8b061e Update lerobot/scripts/eval.py
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
2024-04-10 18:07:27 +02:00
Cadene
0f0113a7a6 print_cuda_memory_usage docstring 2024-04-10 16:03:39 +00:00
Cadene
daecc3b64c Remove sbatch*.sh 2024-04-10 15:58:10 +00:00
Cadene
e8622154f8 Replace import gym_pusht in pusht dataset by dynamic import 2024-04-10 15:56:18 +00:00
Cadene
8866b22db1 remove policy is None eval end-to-end tests 2024-04-10 15:09:04 +00:00
Cadene
2186429fa8 policy.batch_size=2 2024-04-10 15:00:54 +00:00
Cadene
693f620df0 drop_last=False 2024-04-10 14:59:54 +00:00
Cadene
79542ecd13 eval_episodes=1 env.episode_length=8 2024-04-10 14:44:04 +00:00
Cadene
48ec479660 fix end-to-end aloha 2024-04-10 14:26:30 +00:00
Cadene
a18bcb39a7 cfg.env.fps 2024-04-10 14:02:11 +00:00
Cadene
ec4bf115d0 fix example_1 test 2024-04-10 13:55:59 +00:00
Cadene
c08003278e test_examples are passing 2024-04-10 13:45:45 +00:00
Cadene
6082a7bc73 Enable test_available.py 2024-04-10 13:06:48 +00:00
Cadene
7c8eb7ff19 Merge remote-tracking branch 'origin/user/rcadene/2024_03_31_remove_torchrl' into user/rcadene/2024_03_31_remove_torchrl 2024-04-10 11:34:51 +00:00
Cadene
06573d7f67 online training works (loss goes down), remove repeat_action, eval_policy outputs episodes data, eval_policy uses max_episodes_rendered 2024-04-10 11:34:01 +00:00
Simon Alibert
91ff69d64c Update gym_xarm 2024-04-09 17:08:36 +02:00
Alexander Soare
07c28a21ea Merge pull request #68 from alexander-soare/refactor_act
Refactor act
2024-04-09 15:35:20 +01:00
Alexander Soare
575891e8ac Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act 2024-04-09 15:19:29 +01:00
Simon Alibert
7f4ff0b170 CI fix attempt 2024-04-09 11:58:59 +02:00
Simon Alibert
d44950e020 Add ssh key 2024-04-09 11:44:55 +02:00
Simon Alibert
dba0375089 Fix CI 2024-04-09 10:45:58 +02:00
Simon Alibert
d21543eb4f Add env.close() 2024-04-09 10:41:20 +02:00
Simon Alibert
dfaacbcf5a Split dev/test dependencies 2024-04-09 10:40:11 +02:00
Simon Alibert
2573e89e1d Remove direct dependencies 2024-04-09 10:38:08 +02:00
Simon Alibert
274f20b49d Update gym-pusht 2024-04-09 10:25:41 +02:00
Simon Alibert
d9019d9e7e disable env_checker in factory 2024-04-09 10:24:28 +02:00
Alexander Soare
e6c6c2367f Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act 2024-04-09 08:36:28 +01:00
Cadene
19e7661b8d Remove torchrl/tensordict from dependecies + update poetry cpu 2024-04-09 03:50:49 +00:00
Cadene
253e495df2 remove render(mode=visualization) 2024-04-09 03:46:05 +00:00
Cadene
6902e01db0 tests are passing for aloha/act policies, removes abstract policy 2024-04-09 03:28:56 +00:00
Cadene
73dfa3c8e3 tests for tdmpc and diffusion policy are passing 2024-04-09 02:50:32 +00:00
Alexander Soare
50e4c8050c Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act 2024-04-08 17:13:11 +01:00
Remi
1e09507bc1 Merge pull request #69 from huggingface/user/aliberts/2024_04_08_remove_envs
Remove envs
2024-04-08 16:55:20 +02:00
Cadene
1149894e1d rename handle -> task 2024-04-08 14:54:52 +00:00
Alexander Soare
9c96349926 Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act 2024-04-08 15:44:00 +01:00
Simon Alibert
6c792f0d3d Update README 2024-04-08 16:24:11 +02:00
Simon Alibert
3f6dfa4916 Add gym-aloha, rename simxarm -> xarm, refactor 2024-04-08 16:24:11 +02:00
Simon Alibert
5dff6d8339 remove aloha 2024-04-08 16:22:13 +02:00
Cadene
70aaf1c4cb test_datasets.py are passing! 2024-04-08 14:16:57 +00:00
Alexander Soare
91e0e4e175 rever change 2024-04-08 15:05:40 +01:00
Alexander Soare
0b4c42f4ff typos 2024-04-08 14:59:37 +01:00
Alexander Soare
62b18a7607 Add type hints 2024-04-08 14:51:45 +01:00
Alexander Soare
86365adf9f revision 2024-04-08 14:44:46 +01:00
Alexander Soare
0a721f3d94 empty commit 2024-04-08 13:21:38 +01:00
Alexander Soare
863f28ffd8 ready for review 2024-04-08 13:10:19 +01:00
Alexander Soare
1bab4a1dd5 Eval reproduction works with gym_aloha 2024-04-08 10:23:26 +01:00
Alexander Soare
e982c732f1 Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-08 09:25:45 +01:00
Cadene
e1ac5dc62f fix aloha pixels env test 2024-04-07 17:20:54 +00:00
Cadene
4371a5570d Remove latency, tdmpc policy passes tests (TODO: make it work with online RL) 2024-04-07 16:01:22 +00:00
Cadene
44656d2706 test_envs are passing 2024-04-05 23:27:12 +00:00
Alexander Soare
8d2463f45b backup wip 2024-04-05 18:46:30 +01:00
Alexander Soare
ecc7dd3b17 Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 18:35:13 +01:00
Cadene
5eff40b3d6 rename task, sim_transfer -> transfer 2024-04-05 17:18:37 +00:00
Cadene
a2d3588fca wrap dm_control aloha into gymnasium (TODO: properly seeding the env) 2024-04-05 17:17:31 +00:00
Cadene
29032fbcd3 wrap dm_control aloha into gymnasium (TODO: properly seeding the env) 2024-04-05 17:17:14 +00:00
Alexander Soare
ab2286025b Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 18:06:00 +01:00
Alexander Soare
1e71196fe3 backup wip 2024-04-05 17:38:29 +01:00
Cadene
26602269cd test_envs.py are passing, remove simxarm and pusht directories 2024-04-05 16:21:07 +00:00
Alexander Soare
9c28ac8aa4 re-add pre-commit check 2024-04-05 15:25:11 +01:00
Cadene
f56b1a0e16 WIP tdmpc 2024-04-05 13:40:31 +00:00
Simon Alibert
ab3cd3a7ba (WIP) Add gym-xarm 2024-04-05 15:35:20 +02:00
Alexander Soare
0b8d27ff2c Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 12:48:11 +01:00
Cadene
c17dffe944 policies/utils.py 2024-04-05 11:47:15 +00:00
Alexander Soare
8ba88ba250 Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 12:34:14 +01:00
Cadene
a420714ee4 fix: action_is_pad was missing in compute_loss 2024-04-05 11:33:39 +00:00
Alexander Soare
4863e54ce9 Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 12:00:31 +01:00
Cadene
ad3379a73a fix memory leak due to itertools.cycle 2024-04-05 10:59:32 +00:00
Alexander Soare
9d77f5773d Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl 2024-04-05 11:41:11 +01:00
Alexander Soare
edb125b351 backup wip 2024-04-05 11:03:28 +01:00
Cadene
5af00d0c1e fix train.py, stats, eval.py (training is running) 2024-04-05 09:31:39 +00:00
Alexander Soare
3a4dfa82fe backup wip 2024-04-04 18:34:41 +01:00
Cadene
c93ce35d8c WIP stats (TODO: run tests on stats + cmpute them) 2024-04-04 16:36:03 +00:00
Cadene
1cdfbc8b52 WIP
WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
2024-04-04 15:31:03 +00:00
Alexander Soare
278336a39a backup wip 2024-04-03 19:23:22 +01:00
Alexander Soare
110ac5ffa1 backup wip 2024-04-03 14:21:07 +01:00
Alexander Soare
c7d70a8db9 Merge remote-tracking branch 'upstream/main' into refactor_act 2024-04-03 10:08:12 +01:00
Alexander Soare
920e0d118b Merge pull request #66 from alexander-soare/fix_stats_computation
fix stats computation
2024-04-03 10:03:47 +01:00
Alexander Soare
caf4ffcf65 add TODO 2024-04-03 09:56:46 +01:00
Alexander Soare
a6ec4fbf58 remove try-catch 2024-04-03 09:53:15 +01:00
Alexander Soare
c50a62dd6d clarifying math 2024-04-03 09:47:38 +01:00
Alexander Soare
e9eb262293 numerically sound mean computation 2024-04-03 09:44:20 +01:00
Alexander Soare
7242953197 revision 2024-04-02 19:19:13 +01:00
Alexander Soare
65ef8c30d0 backup wip 2024-04-02 19:13:49 +01:00
Alexander Soare
2b928eedd4 backup wip 2024-04-02 19:11:53 +01:00
Alexander Soare
c3234adc7d fix indentation 2024-04-02 16:59:19 +01:00
Alexander Soare
148df1c1d5 add comment on test 2024-04-02 16:57:25 +01:00
Alexander Soare
a6edb85da4 Remove random sampling 2024-04-02 16:52:38 +01:00
Alexander Soare
95293d459d fix stats computation 2024-04-02 16:40:33 +01:00
Alexander Soare
11cbf1bea1 Merge pull request #53 from alexander-soare/finish_examples
Add examples 2 and 3
2024-04-01 11:52:41 +01:00
Alexander Soare
f1148b8c2d Merge remote-tracking branch 'upstream/main' into finish_examples 2024-04-01 11:31:31 +01:00
Simon Alibert
2a98cc71ed Merge pull request #56 from huggingface/user/aliberts/2024_03_27_improve_ci
Add code coverage, more end-to-end tests
2024-03-28 10:57:44 +01:00
Simon Alibert
a7c9b78e56 Deactivate eval ACT on Aloha (policy is None) 2024-03-28 10:55:11 +01:00
Simon Alibert
404b8f8a75 Fix end-to-end ACT train on Aloha 2024-03-28 10:35:11 +01:00
Simon Alibert
17c2bbbeb8 remove todo 2024-03-28 10:35:11 +01:00
Simon Alibert
006e5feabf WIP add code coverage 2024-03-28 10:35:11 +01:00
Simon Alibert
b99ee8180a Add more end-to-end tests 2024-03-28 10:35:11 +01:00
Simon Alibert
6bddcb647e Add test_aloha env test 2024-03-28 10:35:11 +01:00
Simon Alibert
58df2066a9 Add pytest-cov 2024-03-28 10:35:11 +01:00
Simon Alibert
c89aa4f8ed Merge pull request #57 from huggingface/user/aliberts/2024_03_27_improve_readme
Improve readme
2024-03-28 10:26:48 +01:00
Simon Alibert
62aad7104b Pull merge 2024-03-28 10:03:25 +01:00
Simon Alibert
9d9148dad8 Fixes for #57 2024-03-28 10:01:33 +01:00
Simon Alibert
1b6cb2b1be Add space
Co-authored-by: Remi <re.cadene@gmail.com>
2024-03-27 20:51:52 +01:00
Simon Alibert
6f1a0aefab typo fix
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-03-27 20:50:23 +01:00
Alexander Soare
b7c9c33072 revision 2024-03-27 18:33:48 +00:00
Alexander Soare
120f0aef5c Merge remote-tracking branch 'upstream/main' into finish_examples 2024-03-27 17:52:36 +00:00
Simon Alibert
032200e32c Typo fix 2024-03-27 17:05:04 +01:00
Simon Alibert
de1e9187c8 Formatting 2024-03-27 16:56:21 +01:00
Simon Alibert
4f8f1926f9 Update pip install without requirements.txt 2024-03-27 16:49:27 +01:00
Simon Alibert
6710121a29 Revert "Add requirements.txt"
This reverts commit 18e7f4c3e6.
2024-03-27 16:47:49 +01:00
Simon Alibert
5f4b8ab899 Add more exhaustive install instructions 2024-03-27 16:35:32 +01:00
Simon Alibert
18e7f4c3e6 Add requirements.txt 2024-03-27 16:33:54 +01:00
Simon Alibert
643d64e2a8 Add cmake 2024-03-27 16:33:26 +01:00
Alexander Soare
c037722e23 Merge pull request #58 from alexander-soare/update_diffusion_model
Update diffusion model
2024-03-27 13:34:01 +00:00
Alexander Soare
6cd671040f fix revision 2024-03-27 13:22:14 +00:00
Alexander Soare
b6353964ba fix bug: use provided revision instead of hardcoded one 2024-03-27 13:08:47 +00:00
Alexander Soare
64c8851c40 Merge branch 'tidy_diffusion_config' into update_diffusion_model 2024-03-27 13:06:08 +00:00
Alexander Soare
dc745e3037 Remove unused part of diffusion policy config 2024-03-27 13:05:13 +00:00
Simon Alibert
6f0c2445ca Improve readme format 2024-03-27 13:26:54 +01:00
Simon Alibert
d1d2229407 WIP add badges 2024-03-27 13:26:45 +01:00
Alexander Soare
68d02c80cf Remove b/c workaround 2024-03-27 12:03:19 +00:00
Alexander Soare
011f2d27fe fix tests 2024-03-26 16:40:54 +00:00
Alexander Soare
be4441c7ff update README 2024-03-26 16:28:16 +00:00
Alexander Soare
1ed0110900 finish examples 2 and 3 2024-03-26 16:13:40 +00:00
Remi
cb6d1e0871 Merge pull request #49 from huggingface/user/rcadene/2024_03_25_readme
Improve README
2024-03-26 11:50:24 +01:00
Cadene
9ced0cf1fb unskip 2024-03-26 10:45:31 +00:00
Cadene
98534d1a63 skip 2024-03-26 10:42:53 +00:00
Cadene
edacc1d2a0 Add root in example 2024-03-26 10:40:06 +00:00
Cadene
5a46b8a2a9 fix tests 2024-03-26 10:24:46 +00:00
Cadene
4a8c5e238e issue with cat_and_write_video 2024-03-26 10:12:16 +00:00
Alexander Soare
1a1308d62f fix environment seeding
add fixes for reproducibility

only try to start env if it is closed

revision

fix normalization and data type

Improve README

Improve README

Tests are passing, Eval pretrained model works, Add gif

Update gif

Update gif

Update gif

Update gif

Update README

Update README

update minor

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Address suggestions

Update thumbnail + stats

Update thumbnail + stats

Update README.md

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

Add more comments

Add test_examples.py
2024-03-26 10:10:43 +00:00
Simon Alibert
203bcd7ca5 Merge pull request #47 from huggingface/user/aliberts/2024_03_22_fix_simxarm
Port simxarm, upgrade gym to gymnasium
2024-03-26 10:17:52 +01:00
Simon Alibert
98b9631aa6 Add n_obs_steps in default.yaml config 2024-03-26 10:08:00 +01:00
Simon Alibert
c5635b7d94 Minor fixes for #47 2024-03-25 18:50:47 +01:00
Simon Alibert
f00252552a fix ci 2024-03-25 17:38:08 +01:00
Simon Alibert
90f6af9736 Update ci 2024-03-25 17:33:47 +01:00
Simon Alibert
bcfdba109f Update pre-commit & run on all files 2024-03-25 17:29:35 +01:00
Simon Alibert
0fae5b206b Allow mujoco minor revision updates 2024-03-25 17:25:55 +01:00
Simon Alibert
7cdd6d2450 Renamed set_seed -> set_global_seed 2024-03-25 17:19:28 +01:00
Simon Alibert
058ac991eb Add simxarm back into tests 2024-03-25 16:35:46 +01:00
Simon Alibert
d3adaf1379 Add stat.pth for xarm_lift_medium 2024-03-25 15:55:45 +01:00
Simon Alibert
dc89166bee Upgrade gym to gymnasium 2024-03-25 15:12:21 +01:00
Simon Alibert
5ef813ff1e Remove deprecated code 2024-03-25 13:22:49 +01:00
Simon Alibert
a2ac83276b Update ci env 2024-03-25 13:02:50 +01:00
Simon Alibert
c0833f1c2d Remove simxarm download and preproc hack 2024-03-25 12:41:17 +01:00
Simon Alibert
de5c30405e fix wrong version 2024-03-25 12:35:06 +01:00
Simon Alibert
462e7469e8 Add xarm_lift_medium revision 1.0 to hub 2024-03-25 12:28:07 +01:00
Simon Alibert
298d391b26 Add simxarm license 2024-03-25 12:28:07 +01:00
Cadene
be6364f109 fix, it's training now! 2024-03-25 12:28:07 +01:00
Simon Alibert
127de1258d WIP 2024-03-25 12:28:07 +01:00
Cadene
b905111895 fix render issue 2024-03-25 12:28:07 +01:00
Simon Alibert
0c41675986 fix __init__ import Base 2024-03-25 12:28:07 +01:00
Simon Alibert
1c24bbda3f WIP Upgrading simxam from mujoco-py to mujoco python bindings 2024-03-25 12:28:07 +01:00
Alexander Soare
e41c420a96 Merge pull request #48 from alexander-soare/fix_visualization
Fix normalization of last frame and data type in visualization
2024-03-25 09:53:06 +00:00
Alexander Soare
4a48b77540 fix normalization and data type 2024-03-25 09:44:03 +00:00
Remi
f3cfc8b3b4 Merge pull request #46 from huggingface/user/rcadene/2024_03_23_update_stats_v1.2
Fix bug with stats.pth + Move from cadene to lerobot + Update datasets to v1.2
2024-03-24 17:53:32 +01:00
Cadene
d2ef43436c move from cadene to lerobot 2024-03-23 13:34:35 +00:00
Cadene
40f3783fca v1.2 2024-03-23 11:41:56 +00:00
Alexander Soare
e21ed6f510 Merge pull request #45 from alexander-soare/fix_environment_seeding
Reproduce original diffusion policy pusht image eval
2024-03-22 16:27:48 +00:00
Alexander Soare
bd40ffc53c revision 2024-03-22 15:43:45 +00:00
Alexander Soare
d43fa600a0 only try to start env if it is closed 2024-03-22 15:32:55 +00:00
Alexander Soare
e698d38a35 Merge remote-tracking branch 'upstream/main' into fix_environment_seeding 2024-03-22 15:11:15 +00:00
Alexander Soare
15ff3b3af8 add fixes for reproducibility 2024-03-22 15:06:57 +00:00
Remi
a80d9c0257 Merge pull request #44 from alexander-soare/run_model_from_hub
Run model from HuggingFace hub
2024-03-22 16:03:27 +01:00
Alexander Soare
b9047fbdd2 fix environment seeding 2024-03-22 13:25:23 +00:00
Alexander Soare
115927d0f6 make sure to pass stats.pth arg 2024-03-22 12:58:59 +00:00
Alexander Soare
529f42643d revision 2024-03-22 12:33:25 +00:00
Alexander Soare
1b279a1fc0 fix test 2024-03-22 10:58:27 +00:00
Alexander Soare
3f0f95f4c0 update readme 2024-03-22 10:34:22 +00:00
Alexander Soare
8720c568d0 Add ability to eval hub model 2024-03-22 10:26:55 +00:00
Remi
b633748987 Merge pull request #41 from alexander-soare/fix_pusht_diffusion
Fixes issues with PushT diffusion
2024-03-22 00:05:40 +01:00
Alexander Soare
41912b962b remove TODO 2024-03-21 13:51:26 +00:00
Alexander Soare
98361073ef cpu poetry lock 2024-03-21 12:02:24 +00:00
Alexander Soare
48df15ed26 add cpu dep 2024-03-21 11:58:28 +00:00
Alexander Soare
b562f89c3b update deps 2024-03-21 11:42:45 +00:00
Alexander Soare
4e10cd306b revert changes to default.yaml 2024-03-21 10:27:07 +00:00
Alexander Soare
72d3c3120b Merge remote-tracking branch 'upstream/main' into fix_pusht_diffusion 2024-03-21 10:20:52 +00:00
Alexander Soare
acf1174447 ready for review 2024-03-21 10:18:50 +00:00
Simon Alibert
1bd50122be Merge pull request #40 from huggingface/user/aliberts/2024_03_20_enable_mps_device
Enable mps backend for Apple silicon devices
2024-03-20 19:33:12 +01:00
Remi
2b0221052a Merge pull request #39 from huggingface/user/rcadene/2024_03_20_update_stats_v1.1
Update stats for Pusht and Aloha datasets (v1.0 -> v1.1)
2024-03-20 18:41:15 +01:00
Simon Alibert
4631d36c05 Add get_safe_torch_device in policies 2024-03-20 18:38:55 +01:00
Cadene
f23a53c3e4 update stats.pth 2024-03-20 17:36:13 +00:00
Cadene
82e6e01651 v1.1 2024-03-20 17:34:00 +00:00
Alexander Soare
d323993569 backup wip 2024-03-20 15:01:27 +00:00
Remi
ec536ef0fa Merge pull request #26 from Cadene/user/alexander-soare/multistep_policy_and_serial_env
Incorporate SerialEnv and introduct multistep policy logic
2024-03-20 15:55:27 +01:00
Remi
3910c48e43 Merge pull request #37 from Cadene/user/rcadene/2024_03_19_improve_readme
Improve README/LICENSE
2024-03-20 15:50:22 +01:00
Alexander Soare
4b7ec81dde remove abstracmethods, fix online training 2024-03-20 14:49:41 +00:00
Cadene
98a816f0f8 WIP 2024-03-20 14:47:27 +00:00
Cadene
45a4a02b7e WIP 2024-03-20 10:19:55 +00:00
Cadene
8bed0fc465 WIP 2024-03-20 10:10:00 +00:00
Alexander Soare
32e3f71dd1 backup wip 2024-03-20 09:49:16 +00:00
Alexander Soare
5332766a82 revision 2024-03-20 09:45:45 +00:00
Alexander Soare
b1ec3da035 remove internal rendering hooks 2024-03-20 09:23:23 +00:00
Alexander Soare
d16f6a93b3 Merge remote-tracking branch 'upstream/main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-20 09:01:45 +00:00
Alexander Soare
52e149fbfd Only save video frames in first rollout 2024-03-20 08:32:11 +00:00
Alexander Soare
4f1955edfd Clear action queue when environment is reset 2024-03-20 08:31:06 +00:00
Alexander Soare
c5010fee9a fix seeding 2024-03-20 08:21:33 +00:00
Alexander Soare
18fa88475b Move reset_warning_issued flag to class attribute 2024-03-20 08:09:38 +00:00
Alexander Soare
b54cdc9a0f break_when_any_done==True for batch_size==1 2024-03-19 19:08:25 +00:00
Alexander Soare
46ac87d2a6 ready for review 2024-03-19 18:59:08 +00:00
Alexander Soare
896a11f60e backup wip 2024-03-19 18:50:04 +00:00
Remi
2d5abbbd6f Merge pull request #36 from Cadene/user/rcadene/2024_03_19_replay_buffer_folder
Add replay_buffer directory and dataset versioning
2024-03-19 18:04:29 +01:00
Cadene
7d5d99e036 Address more comments 2024-03-19 16:53:07 +00:00
Cadene
b420ab88f4 version naming conventions 2024-03-19 16:44:19 +00:00
Cadene
e799dc5e3f Improve mock_dataset 2024-03-19 16:38:07 +00:00
Cadene
10034e85c4 Aloha done 2024-03-19 16:03:42 +00:00
Alexander Soare
ea17f4ce50 backup wip 2024-03-19 16:02:09 +00:00
Cadene
6a1a29386a Add replay_buffer directory in pusht datasets + aloha (WIP) 2024-03-19 15:49:45 +00:00
Alexander Soare
88347965c2 revert dp changes, make act and tdmpc batch friendly 2024-03-18 19:18:21 +00:00
Alexander Soare
09ddd9bf92 Merge branch 'main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-18 18:27:50 +00:00
Remi
099a465367 Merge pull request #35 from Cadene/switch_between_train_and_eval
Switch between train and eval modes
2024-03-18 13:21:16 +01:00
Alexander Soare
8e346b379d switch between train and eval 2024-03-18 09:45:17 +00:00
Alexander Soare
bae7e7b41c Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-15 14:06:53 +00:00
Remi
75cc10198f Merge pull request #31 from Cadene/disable_wandb_artifact
Fix wandb artifact name and add disable option
2024-03-15 15:06:16 +01:00
Alexander Soare
3124f71ebd Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-15 14:04:23 +00:00
Alexander Soare
4ecfd17f9e fix wandb artifact name and add disable option 2024-03-15 13:56:55 +00:00
Remi
58d1787ee3 Merge pull request #30 from Cadene/user/rcadene/2024_03_15_fix_path
Use Path type instead of str for data_dir
2024-03-15 14:41:33 +01:00
Cadene
b752833f3f fix download 2024-03-15 13:19:18 +00:00
Alexander Soare
a45896dc8d Merge remote-tracking branch 'origin/main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-15 13:05:35 +00:00
Alexander Soare
a222c88c99 Merge branch 'user/alexander-soare/train_pusht' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-14 16:06:21 +00:00
Alexander Soare
736bc969ca Merge branch 'main' into user/alexander-soare/train_pusht 2024-03-14 16:05:11 +00:00
Alexander Soare
4822d63dbe Merge branch 'main' into user/alexander-soare/multistep_policy_and_serial_env 2024-03-14 16:04:40 +00:00
Alexander Soare
ba91976944 wip: still needs batch logic for act and tdmp 2024-03-14 15:24:10 +00:00
Alexander Soare
98484ac68e ready for review 2024-03-12 21:59:01 +00:00
Alexander Soare
9512d1d2f3 Merge branch 'main' into user/alexander-soare/train_pusht 2024-03-12 19:41:27 +00:00
Alexander Soare
87fcc536f9 wip - still need to verify full training run 2024-03-11 18:45:21 +00:00
Alexander Soare
304355c917 Merge remote-tracking branch 'origin/main' into train_pusht 2024-03-11 15:37:37 +00:00
Alexander Soare
2a01487494 early training loss as expected 2024-03-11 13:34:04 +00:00
257 changed files with 9608 additions and 10324 deletions

1
.gitattributes vendored
View File

@@ -1 +1,2 @@
*.memmap filter=lfs diff=lfs merge=lfs -text
*.stl filter=lfs diff=lfs merge=lfs -text

54
.github/ISSUE_TEMPLATE/bug-report.yml vendored Normal file
View File

@@ -0,0 +1,54 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve LeRobot
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to submit a bug report! 🐛
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
- type: textarea
id: system-info
attributes:
label: System Info
description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
render: Shell
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "One of the scripts in the examples/ folder of LeRobot"
- label: "My own task or dataset (give details below)"
- type: textarea
id: reproduction
validations:
required: true
attributes:
label: Reproduction
description: |
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
Sharing error messages or stack traces could be useful as well!
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
placeholder: |
Steps to reproduce the behavior:
1.
2.
3.
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."

15
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View File

@@ -0,0 +1,15 @@
# What does this PR do?
Example: Fixes # (issue)
## Before submitting
- Read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
- Provide a minimal code example for the reviewer to checkout & try.
- Explain how you tested your changes.
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.

1678
.github/poetry/cpu/poetry.lock generated vendored

File diff suppressed because it is too large Load Diff

View File

@@ -1,19 +1,25 @@
[tool.poetry]
name = "lerobot"
version = "0.1.0"
description = "Le robot is learning"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
authors = [
"Rémi Cadène <re.cadene@gmail.com>",
"Alexander Soare <alexander.soare159@gmail.com>",
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
"Simon Alibert <alibert.sim@gmail.com>",
"Thomas Wolf <thomaswolfcontact@gmail.com>",
]
repository = "https://github.com/Cadene/lerobot"
repository = "https://github.com/huggingface/lerobot"
readme = "README.md"
license = "MIT"
license = "Apache-2.0"
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
]
packages = [{include = "lerobot"}]
@@ -21,43 +27,41 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = "^3.10"
cython = "^3.0.8"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
dm-env = "^1.6"
pandas = "^2.2.1"
wandb = "^0.16.3"
moviepy = "^1.0.3"
imageio = {extras = ["pyav"], version = "^2.34.0"}
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
gdown = "^5.1.0"
hydra-core = "^1.3.2"
einops = "^0.7.0"
pygame = "^2.5.2"
pymunk = "^6.6.0"
zarr = "^2.17.0"
shapely = "^2.0.3"
scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^3.1.2"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
dm = "^1.3"
dm-control = "^1.0.16"
huggingface-hub = "^0.21.4"
robomimic = "0.2.0"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
pre-commit = {version = "^3.7.0", optional = true}
debugpy = {version = "^1.8.1", optional = true}
pytest = {version = "^8.1.0", optional = true}
pytest-cov = {version = "^5.0.0", optional = true}
datasets = "^2.18.0"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
pytest = "^8.1.0"
[tool.poetry.extras]
pusht = ["gym-pusht"]
xarm = ["gym-xarm"]
aloha = ["gym-aloha"]
dev = ["pre-commit", "debugpy"]
test = ["pytest", "pytest-cov"]
[[tool.poetry.source]]
@@ -98,10 +102,6 @@ exclude = [
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
[tool.poetry-dynamic-versioning]
enable = true
[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"
requires = ["poetry-core>=1.5.0"]
build-backend = "poetry.core.masonry.api"

View File

@@ -1,4 +1,4 @@
name: Test
name: Tests
on:
pull_request:
@@ -10,20 +10,15 @@ on:
- main
jobs:
test:
tests:
if: |
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
${{ github.event_name == 'push' }}
runs-on: ubuntu-latest
env:
POETRY_VERSION: 1.8.1
POETRY_VERSION: 1.8.2
DATA_DIR: tests/data
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
PYOPENGL_PLATFORM: egl
MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python
@@ -39,6 +34,11 @@ jobs:
with:
python-version: '3.10'
- name: Add SSH key for installing envs
uses: webfactory/ssh-agent@v0.9.0
with:
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
#----------------------------------------------
# install & configure poetry
#----------------------------------------------
@@ -86,9 +86,13 @@ jobs:
- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
env:
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
run: |
mkdir ~/tmp
poetry install --no-interaction --no-root
poetry install --no-interaction --no-root --all-extras
- name: Save cached venv
if: |
@@ -107,38 +111,102 @@ jobs:
# install project
#----------------------------------------------
- name: Install project
run: poetry install --no-interaction
run: poetry install --no-interaction --all-extras
#----------------------------------------------
# run tests
# run tests & coverage
#----------------------------------------------
- name: Run tests
run: |
source .venv/bin/activate
pytest tests
pytest -v --cov=./lerobot --cov-report=xml tests
- name: Test train pusht end-to-end
# TODO(aliberts): Link with HF Codecov account
# - name: Upload coverage reports to Codecov with GitHub Action
# uses: codecov/codecov-action@v4
# with:
# files: ./coverage.xml
# verbose: true
#----------------------------------------------
# run end-to-end tests
#----------------------------------------------
- name: Test train ACT on ALOHA end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
hydra.job.name=pusht \
policy=act \
env=aloha \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/act/
- name: Test eval ACT on ALOHA end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=diffusion \
env=pusht \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=1 \
hydra.run.dir=tests/outputs/
save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/diffusion/
- name: Test eval pusht end-to-end
- name: Test eval Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
hydra.job.name=pusht \
env=pusht \
wandb.enable=False \
--config tests/outputs/diffusion/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/models/1.pt
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=tdmpc \
env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
eval_episodes=1 \
device=cpu \
save_model=true \
save_freq=2 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt

3
.gitignore vendored
View File

@@ -11,6 +11,9 @@ rl
nautilus/*.yaml
*.key
# Slurm
sbatch*.sh
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

View File

@@ -1,9 +1,9 @@
exclude: ^(data/|tests/)
exclude: ^(data/|tests/data)
default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: debug-statements
@@ -14,11 +14,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.1
rev: v3.15.2
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.3.7
hooks:
- id: ruff
args: [--fix]

133
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,133 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or advances of
any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official email address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[feedback@huggingface.co](mailto:feedback@huggingface.co).
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations

274
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,274 @@
# How to contribute to 🤗 LeRobot?
Everyone is welcome to contribute, and we value everybody's contribution. Code
is thus not the only way to help the community. Answering questions, helping
others, reaching out and improving the documentations are immensely valuable to
the community.
It also helps us if you spread the word: reference the library from blog posts
on the awesome projects it made possible, shout out on Twitter when it has
helped you, or simply ⭐️ the repo to say "thank you".
Whichever way you choose to contribute, please be mindful to respect our
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
## You can contribute in so many ways!
Some of the ways you can contribute to 🤗 LeRobot:
* Fixing outstanding issues with the existing code.
* Implementing new models, datasets or simulation environments.
* Contributing to the examples or to the documentation.
* Submitting issues related to bugs or desired new features.
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
## Submitting a new issue or feature request
Do your best to follow these guidelines when submitting an issue or a feature
request. It will make it easier for us to come back to you quickly and with good
feedback.
### Did you find a bug?
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
the problems they encounter. So thank you for reporting an issue.
First, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on Github under Issues).
Did not find it? :( So we can act quickly on it, please follow these steps:
* Include your **OS type and version**, the versions of **Python** and **PyTorch**.
* A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
* The full traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
### Do you want a new feature?
A good feature request addresses the following points:
1. Motivation first:
* Is it related to a problem/frustration with the library? If so, please explain
why. Providing a code snippet that demonstrates the problem is best.
* Is it related to something you would need for a project? We'd love to hear
about it!
* Is it something you worked on and think could benefit the community?
Awesome! Tell us what problem it solved for you.
2. Write a *paragraph* describing the feature.
3. Provide a **code snippet** that demonstrates its future use.
4. In case this is related to a paper, please attach a link.
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
If your issue is well written we're already 80% of the way there by the time you
post it.
## Adding new policies, datasets or environments
Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
environments ([aloha](https://github.com/huggingface/gym-aloha),
[xarm](https://github.com/huggingface/gym-xarm),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Copy it in the required `available_datasets` class attribute
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/lerobot) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
git clone git@github.com:<your Github handle>/lerobot.git
cd lerobot
git remote add upstream https://github.com/huggingface/lerobot.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
git checkout main
git fetch upstream
git rebase upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
git checkout -b a-descriptive-name-for-my-changes
```
🚨 **Do not** work on the `main` branch.
4. Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
Install the project with dev dependencies and all environments:
```bash
poetry install --sync --with dev --all-extras
```
This command should be run when pulling code with and updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
To selectively install environments (for example aloha and pusht) use:
```bash
poetry install --sync --with dev --extras "aloha pusht"
```
The equivalent of `pip install some-package`, would just be:
```bash
poetry add some-package
```
When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
```bash
poetry lock --no-update
```
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash
# Add the new package to your main poetry env
poetry add some-package
# Add the same package to the CPU-only env dedicated to CI
conda create -y -n lerobot-ci python=3.10
conda activate lerobot-ci
cd .github/poetry/cpu
poetry add some-package
```
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
```bash
pytest tests/<TEST_TO_RUN>.py
```
6. Follow our style.
`lerobot` relies on `ruff` to format its source code
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
automatically as Git commit hooks.
Install `pre-commit` hooks:
```bash
pre-commit install
```
You can run these hooks whenever you need on staged files with:
```bash
pre-commit
```
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
git add modified_file.py
git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
git fetch upstream
git rebase upstream/main
```
Push the changes to your account using:
```bash
git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
4. Make sure existing tests pass;
<!-- 5. Add high-coverage tests. No quality testing = no merge.
See an example of a good PR here: https://github.com/huggingface/lerobot/pull/ -->
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```bash
brew install git-lfs
git lfs install
```
On Ubuntu:
```bash
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/data](tests/data)
```bash
git lfs pull
```
We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:
```bash
DATA_DIR="tests/data" python -m pytest -sv ./tests
```
You can specify a smaller set of tests in order to test only the feature
you're working on.

229
LICENSE
View File

@@ -253,6 +253,31 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Some of lerobot's code is derived from simxarm, which is subject to the following copyright notice:
MIT License
Copyright (c) 2023 Nicklas Hansen & Yanjie Ze
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Some of lerobot's code is derived from ALOHA, which is subject to the following copyright notice:
MIT License
@@ -276,3 +301,207 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Some of lerobot's code is derived from DETR, which is subject to the following copyright notice:
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 - present, Facebook, Inc
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

380
README.md
View File

@@ -1,72 +1,297 @@
# LeRobot
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
</picture>
<br/>
<br/>
</p>
<div align="center">
[![Tests](https://github.com/huggingface/lerobot/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot)
[![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE)
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
[![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/)
[![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1%20adopted-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
[![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb)
</div>
<h3 align="center">
<p>State-of-the-art Machine Learning for real-world robotics</p>
</h3>
---
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
#### Examples of pretrained models and environments
<table>
<tr>
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
</tr>
<tr>
<td align="center">ACT policy on ALOHA env</td>
<td align="center">TDMPC policy on SimXArm env</td>
<td align="center">Diffusion policy on PushT env</td>
</tr>
</table>
### Acknowledgment
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
## Installation
Create a virtual environment with Python 3.10, e.g. using `conda`:
```
conda create -y -n lerobot python=3.10
conda activate lerobot
Download our source code:
```bash
git clone https://github.com/huggingface/lerobot.git && cd lerobot
```
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
```
curl -sSL https://install.python-poetry.org | python -
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10 && conda activate lerobot
```
Install dependencies
```
poetry install
Install 🤗 LeRobot:
```bash
python -m pip install .
```
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
```
mkdir ~/tmp
export TMPDIR='~/tmp'
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
- [aloha](https://github.com/huggingface/gym-aloha)
- [xarm](https://github.com/huggingface/gym-xarm)
- [pusht](https://github.com/huggingface/gym-pusht)
For instance, to install 🤗 LeRobot with aloha and pusht, use:
```bash
python -m pip install ".[aloha, pusht]"
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
```
```bash
wandb login
```
## Usage
### Train
## Walkthrough
```
python lerobot/scripts/train.py \
hydra.job.name=pusht \
env=pusht
```
### Visualize offline buffer
.
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
| ├── common # contains classes and utilities
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
| | ├── envs # various sim environments: aloha, pusht, xarm
| | └── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
| ├── eval.py # load policy and evaluate it on an environment
| └── train.py # train a policy via imitation learning and/or reinforcement learning
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
├── .github
| └── workflows
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
└── tests # contains pytest utilities for continuous integration
```
### Visualize datasets
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import os
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
env=pusht
env=pusht \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```
### Visualize online buffer / Eval
### Evaluate a pretrained policy
```
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/eval.py \
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
env=pusht
--hub-id lerobot/diffusion_policy_pusht_image \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub
```
After training your own policy, you can also re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_dir
```
## TODO
See `python lerobot/scripts/eval.py --help` for more instructions.
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
### Train your own policy
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
dataset_id=aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
```
## Contribute
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
### Add a new dataset
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then you can upload it to the hub with:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.0
```
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
```bash
HF_USER=lerobot
DATASET=pusht
```
If you want to improve an existing dataset, you can download it locally with:
```bash
mkdir -p data/$DATASET
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
--repo-type dataset \
--local-dir data/$DATASET \
--local-dir-use-symlinks=False \
--revision v1.0
```
Iterate on your code and dataset with:
```bash
DATA_DIR=data python train.py
```
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.1 \
--delete "*"
```
Then you will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
## Profile
Finally, you might want to mock the dataset if you need to update the unit tests as well:
```bash
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```
**Example**
### Add a pretrained policy
Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
```
With the folder prepared, run the following with a desired revision ID.
```bash
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
```
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
```bash
huggingface-cli upload $HUB_ID to_upload
```
See `eval.py` for an example of how a user may use your policy.
### Improve your code with profiling
An example of a code snippet to profile the evaluation of a policy:
```python
from torch.profiler import profile, record_function, ProfilerActivity
@@ -85,87 +310,12 @@ with profile(
with record_function("eval_policy"):
for i in range(num_episodes):
prof.step()
# insert code to profile, potentially whole body of eval_policy function
```
```bash
python lerobot/scripts/eval.py \
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
--config outputs/pusht/.hydra/config.yaml \
pretrained_model_path=outputs/pusht/model.pt \
eval_episodes=7
```
## Contribute
**Style**
```
# install if needed
pre-commit install
# apply style and linter checks before git commit
pre-commit run -a
```
**Adding dependencies (temporary)**
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
```
# Run in this directory, adds the package to the main env with cuda
poetry add some-package
# Adds the same package to the cpu env
cd .github/poetry/cpu && poetry add some-package
```
**Tests**
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```
brew install git-lfs
git lfs install
```
On Ubuntu:
```
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/data](tests/data)
```
git lfs pull
```
When adding a new dataset, mock it with
```
python tests/scripts/mock_dataset.py --in-data-dir data/<dataset_id> --out-data-dir tests/data/<dataset_id>
```
Run tests
```
DATA_DIR="tests/data" pytest -sx tests
```
**Datasets**
To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
```
Then you can upload it to the hub with:
```
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
```
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
```
HF_USER=cadene
DATASET=pusht
```
## Acknowledgment
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
- Our ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)

View File

@@ -0,0 +1,656 @@
"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
"""
import io
import json
import pickle
import shutil
from pathlib import Path
import einops
import h5py
import numpy as np
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
def download_and_upload(root, revision, dataset_id):
if "pusht" in dataset_id:
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha_sim" in dataset_id:
download_and_upload_aloha(root, revision, dataset_id)
elif "aloha_mobile" in dataset_id:
download_and_upload_aloha_mobile(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
import zipfile
import requests
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
else:
return False
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
# create and store meta_data
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
api = HfApi()
# info
info_path = meta_data_dir / "info.json"
with open(str(info_path), "w") as f:
json.dump(info, f, indent=4)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# stats
stats_path = meta_data_dir / "stats.safetensors"
save_file(flatten_dict(stats), stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# episode_data_index
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/{dataset_id}/train"
)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
ReplayBuffer as DiffusionPolicyReplayBuffer,
)
except ModuleNotFoundError as e:
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
raise e
# as define in env
success_threshold = 0.95 # 95% coverage,
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
zarr_path = (raw_dir / pusht_zarr).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(pusht_url, raw_dir)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
id_to = dataset_dict.meta["episode_ends"][episode_id]
num_frames = id_to - id_from
assert (episode_ids[id_from:id_to] == episode_id).all()
image = imgs[id_from:id_to]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[id_from:id_to]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames)
success = torch.zeros(num_frames, dtype=torch.bool)
done = torch.zeros(num_frames, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / success_threshold, 0, 1)
success[i] = coverage > success_threshold
# last step of demonstration is considered done
done[-1] = True
ep_dict = {
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[id_from:id_to],
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
root = Path(root)
raw_dir = root / "xarm_datasets_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
zip_path.unlink()
dataset_path = root / f"{dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
id_to = 0
episode_id = 0
total_frames = dataset_dict["actions"].shape[0]
for i in tqdm.tqdm(range(total_frames)):
id_to += 1
if not dataset_dict["dones"][i]:
continue
num_frames = id_to - id_from
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
image = einops.rearrange(image, "b c h w -> b h w c")
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
ep_dict = {
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from = id_to
episode_id += 1
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
folder_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}
ep48_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}
ep49_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}
num_episodes = {
"aloha_sim_insertion_human": 50,
"aloha_sim_insertion_scripted": 50,
"aloha_sim_transfer_cube_human": 50,
"aloha_sim_transfer_cube_scripted": 50,
}
episode_len = {
"aloha_sim_insertion_human": 500,
"aloha_sim_insertion_scripted": 400,
"aloha_sim_transfer_cube_human": 400,
"aloha_sim_transfer_cube_scripted": 400,
}
cameras = {
"aloha_sim_insertion_human": ["top"],
"aloha_sim_insertion_scripted": ["top"],
"aloha_sim_transfer_cube_human": ["top"],
"aloha_sim_transfer_cube_scripted": ["top"],
}
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
if not raw_dir.is_dir():
import gdown
assert dataset_id in folder_urls
assert dataset_id in ep48_urls
assert dataset_id in ep49_urls
raw_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
assert episode_len[dataset_id] == num_frames
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_dict = {}
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image
ep_dict.update(
{
"observation.state": state,
"action": action,
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.images.top": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
num_episodes = {
"aloha_mobile_trossen_block_handoff": 5,
}
# episode_len = {
# "aloha_sim_insertion_human": 500,
# "aloha_sim_insertion_scripted": 400,
# "aloha_sim_transfer_cube_human": 400,
# "aloha_sim_transfer_cube_scripted": 400,
# }
cameras = {
"aloha_mobile_trossen_block_handoff": ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
}
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
#assert episode_len[dataset_id] == num_frames
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
action = torch.from_numpy(ep["/action"][:num_frames])
ep_dict = {}
for cam in cameras[dataset_id]:
image = ep[f"/observations/images/{cam}"][:num_frames] # b h w c
import cv2
# un-pad and uncompress from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
image = np.array([cv2.imdecode(x, 1) for x in image])
image = [PILImage.fromarray(x) for x in image]
ep_dict[f"observation.images.{cam}"] = image
ep_dict.update(
{
"observation.state": state,
"action": action,
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
break
data_dict = concatenate_episodes(ep_dicts)
features = {}
for cam in cameras[dataset_id]:
features[f"observation.images.{cam}"] = Image()
features.update({
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
})
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":
root = "data"
revision = "v1.1"
dataset_ids = [
# "pusht",
# "xarm_lift_medium",
# "xarm_lift_medium_replay",
# "xarm_push_medium",
# "xarm_push_medium_replay",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
# "aloha_sim_transfer_cube_scripted",
"aloha_mobile_trossen_block_handoff",
]
for dataset_id in dataset_ids:
download_and_upload(root, revision, dataset_id)

View File

@@ -0,0 +1,67 @@
"""
This script demonstrates the visualization of various robotic datasets from Hugging Face hub.
It covers the steps from loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
It is compatible with pytorch, jax, numpy, etc.
As an example, this script saves frames of episode number 5 of the PushT dataset to a mp4 video and saves the result here:
`outputs/examples/1_visualize_hugging_face_datasets/episode_5.mp4`
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
To try a different Hugging Face dataset, you can replace this line:
```python
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
```
by one of these:
```python
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
```
"""
from pathlib import Path
import imageio
from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
# display name of dataset and its features
print(f"{hf_dataset=}")
print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
print(
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
)
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]
# save episode frames to a mp4 video
Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)

View File

@@ -0,0 +1,102 @@
"""
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
To try a different Hugging Face dataset, you can replace:
```python
dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
```
"""
from pathlib import Path
import imageio
import torch
from lerobot.common.datasets.pusht import PushtDataset
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
# print("List of available datasets", lerobot.available_datasets)
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
# # 'pusht', 'xarm_lift_medium']
# You can easily load datasets from LeRobot
dataset = PushtDataset()
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
# and provide additional utilities for robotics and compatibility with pytorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
# to view them, we convert to uint8 range [0,255] channel last
frames = [(frame * 255).type(torch.uint8) for frame in frames]
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video
Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
# using timestamps differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
dataset = PushtDataset(delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
# because they are just PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
print(f"{batch['observation.state'].shape=}") # (32,8,c)
print(f"{batch['action'].shape=}") # (32,64,c)
break

View File

@@ -0,0 +1,40 @@
"""
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
"""
from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# folder = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"device=cuda",
]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos.
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@@ -0,0 +1,68 @@
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)
# Number of offline training steps (we'll only do offline training for this example.
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 5000
device = torch.device("cuda")
log_freq = 250
# Set up the dataset.
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(hydra_cfg)
# Set up the the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
# For this example, no arguments need to be passed because the defaults are set up for PushT.
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy.train()
policy.to(device)
# Create dataloader for offline training.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.batch_size,
shuffle=True,
pin_memory=device != torch.device("cpu"),
drop_last=True,
)
# Run training loop.
step = 0
done = False
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
info = policy.update(batch)
if step % log_freq == 0:
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
step += 1
if step >= training_steps:
done = True
break
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")

View File

@@ -1 +1,80 @@
"""
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
Example:
```python
import lerobot
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
```
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
"xarm",
]
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
"AlohaTransferCube-v0",
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_datasets = {
"aloha": [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_policies = [
"act",
"diffusion",
"tdmpc",
]
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]

View File

@@ -1,4 +1,4 @@
""" To enable `lerobot.__version__` """
"""To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version

View File

@@ -1,3 +1,8 @@
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
"""
from __future__ import annotations
import math

View File

@@ -1,159 +0,0 @@
import logging
from pathlib import Path
from typing import Callable
import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = root
storage = self._download_or_load_dataset()
super().__init__(
storage=storage,
sampler=sampler,
writer=ImmutableDatasetWriter() if writer is None else writer,
collate_fn=_collate_id if collate_fn is None else collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
@property
def stats_patterns(self) -> dict:
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action",): "b c -> 1 c",
}
@property
def image_keys(self) -> list:
return [("observation", "image")]
@property
def num_cameras(self) -> int:
return len(self.image_keys)
@property
def num_samples(self) -> int:
return len(self)
@property
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def transform(self):
return self._transform
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"
if stats_path.exists():
stats = torch.load(stats_path)
else:
logging.info(f"compute_stats and save to {stats_path}")
stats = self._compute_stats(num_batch, batch_size)
torch.save(stats, stats_path)
return stats
def _download_or_load_dataset(self) -> torch.StorageBase:
if self.root is None:
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
else:
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(
storage=self._storage,
batch_size=batch_size,
prefetch=True,
)
mean, std, max, min = {}, {}, {}, {}
# compute mean, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
if key not in mean:
# first batch initialize mean, min, max
mean[key] = einops.reduce(batch[key], pattern, "mean")
max[key] = einops.reduce(batch[key], pattern, "max")
min[key] = einops.reduce(batch[key], pattern, "min")
else:
mean[key] += einops.reduce(batch[key], pattern, "mean")
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
batch = rb.sample()
for key in self.stats_patterns:
mean[key] /= num_batch
# compute std, min, max
for _ in tqdm.tqdm(range(num_batch)):
batch = rb.sample()
for key, pattern in self.stats_patterns.items():
batch[key] = batch[key].float()
batch_mean = einops.reduce(batch[key], pattern, "mean")
if key not in std:
# first batch initialize std
std[key] = (batch_mean - mean[key]) ** 2
else:
std[key] += (batch_mean - mean[key]) ** 2
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
for key in self.stats_patterns:
std[key] = torch.sqrt(std[key] / num_batch)
stats = TensorDict({}, batch_size=[])
for key in self.stats_patterns:
stats[(*key, "mean")] = mean[key]
stats[(*key, "std")] = std[key]
stats[(*key, "max")] = max[key]
stats[(*key, "min")] = min[key]
if key[0] == "observation":
# use same stats for the next observations
stats[("next", *key)] = stats[key]
return stats

View File

@@ -1,183 +1,78 @@
import logging
from pathlib import Path
from typing import Callable
import einops
import gdown
import h5py
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
DATASET_IDS = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
FOLDER_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
}
EP48_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
}
EP49_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
}
NUM_EPISODES = {
"aloha_sim_insertion_human": 50,
"aloha_sim_insertion_scripted": 50,
"aloha_sim_transfer_cube_human": 50,
"aloha_sim_transfer_cube_scripted": 50,
}
EPISODE_LEN = {
"aloha_sim_insertion_human": 500,
"aloha_sim_insertion_scripted": 400,
"aloha_sim_transfer_cube_human": 400,
"aloha_sim_transfer_cube_scripted": 400,
}
CAMERAS = {
"aloha_sim_insertion_human": ["top"],
"aloha_sim_insertion_scripted": ["top"],
"aloha_sim_transfer_cube_human": ["top"],
"aloha_sim_transfer_cube_scripted": ["top"],
}
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
def download(data_dir, dataset_id):
assert dataset_id in DATASET_IDS
assert dataset_id in FOLDER_URLS
assert dataset_id in EP48_URLS
assert dataset_id in EP49_URLS
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
data_dir.mkdir(parents=True, exist_ok=True)
# Copied from lerobot/__init__.py
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50
image_keys = ["observation.images.top"]
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
version: str | None = "v1.1",
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
assert dataset_id in DATASET_IDS
super().__init__(
dataset_id,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action",): "b c -> 1 c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
return d
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
def __len__(self):
return self.num_samples
total_num_frames = 0
logging.info("Compute total number of frames to initialize offline buffer")
for ep_id in range(NUM_EPISODES[self.dataset_id]):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
total_num_frames += ep["/action"].shape[0] - 1
logging.info(f"{total_num_frames=}")
def __getitem__(self, idx):
item = self.hf_dataset[idx]
logging.info("Initialize and feed offline buffer")
idxtd = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
ep_num_frames = ep["/action"].shape[0]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# last step of demonstration is considered done
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
done[-1] = True
if self.transform is not None:
item = self.transform(item)
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_td = TensorDict(
{
("observation", "state"): state[:-1],
"action": action[:-1],
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
("next", "observation", "state"): state[1:],
# TODO: compute reward and success
# ("next", "reward"): reward[1:],
("next", "done"): done[1:],
# ("next", "success"): success[1:],
},
batch_size=ep_num_frames - 1,
)
for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_td["observation", "image", cam] = image[:-1]
ep_td["next", "observation", "image", cam] = image[1:]
if ep_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_())
return item

View File

@@ -1,131 +1,90 @@
import logging
import os
from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from torchvision.transforms import v2
from lerobot.common.envs.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
batch_size = None
pin_memory = False
prefetch = None
else:
assert cfg.online_steps == 0
num_slices = cfg.policy.batch_size
batch_size = cfg.policy.horizon * num_slices
pin_memory = cfg.device == "cuda"
prefetch = cfg.prefetch
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
if overwrite_batch_size is not None:
batch_size = overwrite_batch_size
if overwrite_prefetch is not None:
prefetch = overwrite_prefetch
if overwrite_sampler is None:
# TODO(rcadene): move batch_size outside
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
if cfg.offline_prioritized_sampler:
logging.info("use prioritized sampler for offline dataset")
sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.policy.per_alpha,
beta=cfg.policy.per_beta,
num_slices=num_traj_per_batch,
strict_length=False,
)
else:
logging.info("use simple sampler for offline dataset")
sampler = SliceSampler(
num_slices=num_traj_per_batch,
strict_length=False,
)
else:
sampler = overwrite_sampler
if cfg.env.name == "simxarm":
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
clsfunc = SimxarmExperienceReplay
dataset_id = f"xarm_{cfg.env.task}_medium"
clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtExperienceReplay
dataset_id = "pusht"
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaExperienceReplay
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaExperienceReplay
dataset_id = f"aloha_{cfg.env.task}"
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
offline_buffer = clsfunc(
dataset_id=dataset_id,
sampler=sampler,
batch_size=batch_size,
root=DATA_DIR,
pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
transforms = None
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
in_keys += img_keys
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
offline_buffer.set_transform(transforms)
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)
sampler.extend(index)
transforms = v2.Compose(
[
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)
return offline_buffer
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
for key in delta_timestamps:
if isinstance(delta_timestamps[key], str):
delta_timestamps[key] = eval(delta_timestamps[key])
dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)
return dataset

View File

@@ -1,219 +1,76 @@
from pathlib import Path
from typing import Callable
import einops
import numpy as np
import pygame
import pymunk
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
# as define in env
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
def get_goal_pose_body(pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
delta_timestamps : dict[list[float]] | None, optional
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
def add_segment(space, a, b, radius):
shape = pymunk.Segment(space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
# Copied from lerobot/__init__.py
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]
def add_tee(
space,
position,
angle,
scale=30,
color="LightSlateGray",
mask=None,
):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
space.add(body, shape1, shape2)
return body
class PushtExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
dataset_id: str = "pusht",
version: str | None = "v1.1",
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__(
dataset_id,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
def _download_and_preproc_obsolete(self):
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True)
download_and_extract_zip(PUSHT_URL, raw_dir)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
# load
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
zarr_path
) # , keys=['img', 'state', 'action'])
@property
def num_episodes(self) -> int:
return len(self.episode_data_index["from"])
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames."
def __len__(self):
return self.num_samples
# TODO: verify that goal pose is expected to be fixed
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = get_goal_pose_body(goal_pos_angle)
def __getitem__(self, idx):
item = self.hf_dataset[idx]
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
idx0 = 0
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
# idx1 = 51
num_frames = idx1 - idx0
assert (episode_ids[idx0:idx1] == episode_id).all()
image = imgs[idx0:idx1]
state = states[idx0:idx1]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
reward = torch.zeros(num_frames, 1)
success = torch.zeros(num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
for i in range(num_frames):
space = pymunk.Space()
space.gravity = 0, 0
space.damping = 0
# Add walls.
walls = [
add_segment(space, (5, 506), (5, 5), 2),
add_segment(space, (5, 5), (506, 5), 2),
add_segment(space, (506, 5), (506, 506), 2),
add_segment(space, (5, 506), (506, 506), 2),
]
space.add(*walls)
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
success[i] = coverage > SUCCESS_THRESHOLD
# last step of demonstration is considered done
done[-1] = True
ep_td = TensorDict(
{
("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1],
"action": actions[idx0:idx1][:-1],
"episode": episode_ids[idx0:idx1][:-1],
"frame_id": torch.arange(0, num_frames - 1, 1),
("next", "observation", "image"): image[1:],
("next", "observation", "state"): agent_pos[1:],
# TODO: verify that reward and done are aligned with image and agent_pos
("next", "reward"): reward[1:],
("next", "done"): done[1:],
("next", "success"): success[1:],
},
batch_size=num_frames - 1,
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
if self.transform is not None:
item = self.transform(item)
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idx0 = idx1
idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_())
return item

View File

@@ -1,121 +0,0 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import (
SliceSampler,
)
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
def download():
raise NotImplementedError()
import gdown
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()
class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
]
def __init__(
self,
dataset_id: str,
batch_size: int = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
):
super().__init__(
dataset_id,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
def _download_and_preproc_obsolete(self):
assert self.root is not None
# TODO(rcadene): finish download
download()
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
idx0 = 0
idx1 = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
},
batch_size=num_frames,
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idx0:idx1] = episode
episode_id += 1
idx0 = idx1
return TensorStorage(td_data.lock_())

View File

@@ -1,30 +1,336 @@
import io
import zipfile
from copy import deepcopy
from math import ceil
from pathlib import Path
import requests
import datasets
import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.utils.utils import set_global_seed
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
print(f"downloading from {url}")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
def flatten_dict(d, parent_key="", sep="/"):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
zip_file = io.BytesIO()
for chunk in response.iter_content(chunk_size=1024):
if chunk:
zip_file.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
zip_file.seek(0)
with zipfile.ZipFile(zip_file, "r") as zip_ref:
zip_ref.extractall(destination_folder)
return True
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(Path(root) / dataset_id / split)
else:
return False
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
of the episode are the same as the first frame of the episode.
Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
Returns:
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
each modality (e.g. "observation.image_is_pad").
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
current_ts = item["timestamp"].item()
for key in delta_timestamps:
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tol
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
# get dataset indices corresponding to frames to be loaded
data_ids = ep_data_ids[argmin_]
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are returned in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(hf_dataset)
stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
set_global_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
return stats
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
"""
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)

View File

@@ -0,0 +1,78 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@@ -1,80 +0,0 @@
import abc
from collections import deque
from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
class AbstractEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
self.num_prev_obs = num_prev_obs
self.num_prev_action = num_prev_action
self._rendering_hooks = []
if pixels_only:
assert from_pixels
if from_pixels:
assert image_size
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
if self.num_prev_action > 0:
raise NotImplementedError()
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
def register_rendering_hook(self, func):
self._rendering_hooks.append(func)
def call_rendering_hooks(self):
for func in self._rendering_hooks:
func(self)
def reset_rendering_hooks(self):
self._rendering_hooks = []
@abc.abstractmethod
def render(self, mode="rgb_array", width=640, height=480):
raise NotImplementedError()
@abc.abstractmethod
def _reset(self, tensordict: Optional[TensorDict] = None):
raise NotImplementedError()
@abc.abstractmethod
def _step(self, tensordict: TensorDict):
raise NotImplementedError()
@abc.abstractmethod
def _make_env(self):
raise NotImplementedError()
@abc.abstractmethod
def _make_spec(self):
raise NotImplementedError()
@abc.abstractmethod
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError()

View File

@@ -1,59 +0,0 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>
<body name="peg" pos="0.2 0.5 0.05">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>
<body name="socket" pos="-0.2 0.5 0.05">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@@ -1,48 +0,0 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<equality>
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
</equality>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
</body>
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
</body>
<body name="box" pos="0.2 0.5 0.05">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@@ -1,53 +0,0 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body name="peg" pos="0.2 0.5 0.05">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>
<body name="socket" pos="-0.2 0.5 0.05">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@@ -1,42 +0,0 @@
<mujoco>
<include file="scene.xml"/>
<include file="vx300s_dependencies.xml"/>
<worldbody>
<include file="vx300s_left.xml" />
<include file="vx300s_right.xml" />
<body name="box" pos="0.2 0.5 0.05">
<joint name="red_box_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
</body>
</worldbody>
<actuator>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
</actuator>
<keyframe>
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
</keyframe>
</mujoco>

View File

@@ -1,38 +0,0 @@
<mujocoinclude>
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
<asset>
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
</asset>
<visual>
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
<quality shadowsize="4096" offsamples="4"/>
<headlight ambient="0.4 0.4 0.4"/>
</visual>
<worldbody>
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
dir='1 1 -1'/>
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
dir='0 -1 -1'/>
<body name="table" pos="0 .6 0">
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
</body>
<body name="midair" pos="0 .6 0.2">
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
</body>
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
</worldbody>
</mujocoinclude>

View File

@@ -1,17 +0,0 @@
<mujocoinclude>
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
<asset>
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
</asset>
</mujocoinclude>

View File

@@ -1,59 +0,0 @@
<mujocoinclude>
<body name="vx300s_left" pos="-0.469 0.5 0">
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
</body>
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
</body>
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
</body>
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</mujocoinclude>

View File

@@ -1,59 +0,0 @@
<mujocoinclude>
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
</body>
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
</body>
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
</body>
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</mujocoinclude>

View File

@@ -1,163 +0,0 @@
from pathlib import Path
### Simulation envs fixed constants
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
FPS = 50
JOINTS = [
# absolute joint position
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"left_arm_gripper",
# absolute joint position
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position 0: close, 1: open
"right_arm_gripper",
]
ACTIONS = [
# position and quaternion for end effector
"left_arm_waist",
"left_arm_shoulder",
"left_arm_elbow",
"left_arm_forearm_roll",
"left_arm_wrist_angle",
"left_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"left_arm_gripper",
"right_arm_waist",
"right_arm_shoulder",
"right_arm_elbow",
"right_arm_forearm_roll",
"right_arm_wrist_angle",
"right_arm_wrist_rotate",
# normalized gripper position (0: close, 1: open)
"right_arm_gripper",
]
START_ARM_POSE = [
0,
-0.96,
1.16,
0,
-0.3,
0,
0.02239,
-0.02239,
0,
-0.96,
1.16,
0,
-0.3,
0,
0.02239,
-0.02239,
]
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
############################ Helper functions ############################
def normalize_master_gripper_position(x):
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
def normalize_puppet_gripper_position(x):
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
def unnormalize_master_gripper_position(x):
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
def unnormalize_puppet_gripper_position(x):
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
def convert_position_from_master_to_puppet(x):
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
def normalizer_master_gripper_joint(x):
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
def normalize_puppet_gripper_joint(x):
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
def unnormalize_master_gripper_joint(x):
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
def unnormalize_puppet_gripper_joint(x):
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
def convert_join_from_master_to_puppet(x):
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
def normalize_master_gripper_velocity(x):
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
def normalize_puppet_gripper_velocity(x):
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
def convert_master_from_position_to_joint(x):
return (
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
def convert_master_from_joint_to_position(x):
return unnormalize_master_gripper_position(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
def convert_puppet_from_position_to_join(x):
return (
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
def convert_puppet_from_joint_to_position(x):
return unnormalize_puppet_gripper_position(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

View File

@@ -1,311 +0,0 @@
import importlib
import logging
from collections import deque
from typing import Optional
import einops
import numpy as np
import torch
from dm_control import mujoco
from dm_control.rl import control
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.envs.aloha.constants import (
ACTIONS,
ASSETS_DIR,
DT,
JOINTS,
)
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
InsertionEndEffectorTask,
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
class AlohaEnv(AbstractEnv):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
if not self.from_pixels:
raise NotImplementedError()
self._env = self._make_env_task(self.task)
def render(self, mode="rgb_array", width=640, height=480):
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
image = self._env.physics.render(height=height, width=width, camera_id="top")
return image
def _make_env_task(self, task_name):
# time limit is controlled by StepCounter in env factory
time_limit = float("inf")
if "sim_transfer_cube" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeTask(random=False)
elif "sim_insertion" in task_name:
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionTask(random=False)
elif "sim_end_effector_transfer_cube" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = TransferCubeEndEffectorTask(random=False)
elif "sim_end_effector_insertion" in task_name:
raise NotImplementedError()
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
physics = mujoco.Physics.from_xml_path(str(xml_path))
task = InsertionEndEffectorTask(random=False)
else:
raise NotImplementedError(task_name)
env = control.Environment(
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
)
return env
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = torch.from_numpy(raw_obs["images"]["top"].copy())
image = einops.rearrange(image, "h w c -> c h w")
assert image.dtype == torch.uint8
obs = {"image": {"top": image}}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
else:
# TODO(rcadene):
raise NotImplementedError()
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif "sim_insertion" in self.task:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs.observation)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
num_action_steps = action.shape[0]
for i in range(num_action_steps):
_, reward, discount, raw_obs = self._env.step(action[i])
del discount # not used
# TOOD(rcadene): add an enum
success = done = reward == 4
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"]["top"])
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
from omegaconf import OmegaConf
if self.from_pixels:
if isinstance(self.image_size, int):
image_shape = (3, self.image_size, self.image_size)
elif OmegaConf.is_list(self.image_size):
assert len(self.image_size) == 3 # c h w
assert self.image_size[0] == 3 # c is RGB
image_shape = tuple(self.image_size)
else:
raise ValueError(self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = {
"top": BoundedTensorSpec(
low=0,
high=255,
shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
}
if not self.pixels_only:
state_shape = (len(JOINTS),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO: add low and high bounds
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = (len(JOINTS),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO: add low and high bounds
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
# TODO(rcadene): valid when controling end effector?
# action_space = self._env.action_spec()
# self.action_spec = BoundedTensorSpec(
# low=action_space.minimum,
# high=action_space.maximum,
# shape=action_space.shape,
# dtype=torch.float32,
# device=self.device,
# )
# TODO(rcaene): add bounds (where are they????)
self.action_spec = BoundedTensorSpec(
shape=(len(ACTIONS)),
low=-1,
high=1,
dtype=torch.float32,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
# TODO(rcadene): seed the env
# self._env.seed(seed)
logging.warning("Aloha env is not seeded")

View File

@@ -1,219 +0,0 @@
import collections
import numpy as np
from dm_control.suite import base
from lerobot.common.envs.aloha.constants import (
START_ARM_POSE,
normalize_puppet_gripper_position,
normalize_puppet_gripper_velocity,
unnormalize_puppet_gripper_position,
)
BOX_POSE = [None] # to be changed from outside
"""
Environment for simulated robot bi-manual manipulation, with joint position control
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
class BimanualViperXTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
left_arm_action = action[:6]
right_arm_action = action[7 : 7 + 6]
normalized_left_gripper_action = action[6]
normalized_right_gripper_action = action[7 + 6]
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
env_action = np.concatenate(
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
)
super().before_step(env_action, physics)
return
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos(physics)
obs["qvel"] = self.get_qvel(physics)
obs["env_state"] = self.get_env_state(physics)
obs["images"] = {}
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
return obs
def get_reward(self, physics):
# return whether left gripper is holding the box
raise NotImplementedError
class TransferCubeTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7:] = BOX_POSE[0]
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = (
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
)
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = (
("socket-1", "table") in all_contact_pairs
or ("socket-2", "table") in all_contact_pairs
or ("socket-3", "table") in all_contact_pairs
or ("socket-4", "table") in all_contact_pairs
)
peg_touch_socket = (
("red_peg", "socket-1") in all_contact_pairs
or ("red_peg", "socket-2") in all_contact_pairs
or ("red_peg", "socket-3") in all_contact_pairs
or ("red_peg", "socket-4") in all_contact_pairs
)
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if (
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward

View File

@@ -1,263 +0,0 @@
import collections
import numpy as np
from dm_control.suite import base
from lerobot.common.envs.aloha.constants import (
PUPPET_GRIPPER_POSITION_CLOSE,
START_ARM_POSE,
normalize_puppet_gripper_position,
normalize_puppet_gripper_velocity,
unnormalize_puppet_gripper_position,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
"""
Environment for simulated robot bi-manual manipulation, with end-effector control.
Action space: [left_arm_pose (7), # position and quaternion for end effector
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_pose (7), # position and quaternion for end effector
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
class BimanualViperXEndEffectorTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
a_len = len(action) // 2
action_left = action[:a_len]
action_right = action[a_len:]
# set mocap position and quat
# left
np.copyto(physics.data.mocap_pos[0], action_left[:3])
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
# right
np.copyto(physics.data.mocap_pos[1], action_right[:3])
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
# set gripper
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
def initialize_robots(self, physics):
# reset joint position
physics.named.data.qpos[:16] = START_ARM_POSE
# reset mocap to align with end effector
# to obtain these numbers:
# (1) make an ee_sim env and reset to the same start_pose
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
# repeat the same for right side
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
# right
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
# reset gripper control
close_gripper_control = np.array(
[
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
]
)
np.copyto(physics.data.ctrl, close_gripper_control)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
# note: it is important to do .copy()
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos(physics)
obs["qvel"] = self.get_qvel(physics)
obs["env_state"] = self.get_env_state(physics)
obs["images"] = {}
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
# used in scripted policy to obtain starting pose
obs["mocap_pose_left"] = np.concatenate(
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
).copy()
obs["mocap_pose_right"] = np.concatenate(
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
).copy()
# used when replaying joint trajectory
obs["gripper_ctrl"] = physics.data.ctrl.copy()
return obs
def get_reward(self, physics):
raise NotImplementedError
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize box position
cube_pose = sample_box_pose()
box_start_idx = physics.model.name2id("red_box_joint", "joint")
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize peg and socket position
peg_pose, socket_pose = sample_insertion_pose()
def id2index(j_id):
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
peg_start_idx = id2index(peg_start_id)
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
# print(f"randomized cube position to {cube_position}")
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
socket_start_idx = id2index(socket_start_id)
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = (
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
)
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = (
("socket-1", "table") in all_contact_pairs
or ("socket-2", "table") in all_contact_pairs
or ("socket-3", "table") in all_contact_pairs
or ("socket-4", "table") in all_contact_pairs
)
peg_touch_socket = (
("red_peg", "socket-1") in all_contact_pairs
or ("red_peg", "socket-2") in all_contact_pairs
or ("red_peg", "socket-3") in all_contact_pairs
or ("red_peg", "socket-4") in all_contact_pairs
)
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if (
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward

View File

@@ -1,39 +0,0 @@
import numpy as np
def sample_box_pose():
x_range = [0.0, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
cube_quat = np.array([1, 0, 0, 0])
return np.concatenate([cube_position, cube_quat])
def sample_insertion_pose():
# Peg
x_range = [0.1, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
peg_quat = np.array([1, 0, 0, 0])
peg_pose = np.concatenate([peg_position, peg_quat])
# Socket
x_range = [-0.2, -0.1]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
socket_quat = np.array([1, 0, 0, 0])
socket_pose = np.concatenate([socket_position, socket_quat])
return peg_pose, socket_pose

View File

@@ -1,73 +1,42 @@
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
import importlib
import gymnasium as gym
def make_env(cfg, transform=None):
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
"""
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
"""
kwargs = {
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
"num_prev_obs": cfg.n_obs_steps - 1,
"obs_type": "pixels_agent_pos",
"render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
"visualization_width": 384,
"visualization_height": 384,
}
if cfg.env.name == "simxarm":
from lerobot.common.envs.simxarm import SimxarmEnv
package_name = f"gym_{cfg.env.name}"
kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv
elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht.env import PushtEnv
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
)
raise e
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
gym_handle = f"{package_name}/{cfg.env.task}"
clsfunc = PushtEnv
elif cfg.env.name == "aloha":
from lerobot.common.envs.aloha.env import AlohaEnv
kwargs["task"] = cfg.env.task
clsfunc = AlohaEnv
if num_parallel_envs == 0:
# non-batched version of the env that returns an observation of shape (c)
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
else:
raise ValueError(cfg.env.name)
env = clsfunc(**kwargs)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if transform is not None:
# useful to add normalization
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
# batched version of the env that returns an observation of shape (b, c)
env = gym.vector.SyncVectorEnv(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
for _ in range(num_parallel_envs)
]
)
return env
# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env

View File

@@ -1,234 +0,0 @@
import importlib
from collections import deque
from typing import Optional
import einops
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
_has_gym = importlib.util.find_spec("gym") is not None
class PushtEnv(AbstractEnv):
def __init__(
self,
task="pusht",
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=1,
num_prev_action=0,
):
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
if not self.from_pixels:
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
self._env = PushTImageEnv(render_size=self.image_size)
def render(self, mode="rgb_array", width=384, height=384):
if width != height:
raise NotImplementedError()
tmp = self._env.render_size
self._env.render_size = width
out = self._env.render(mode)
self._env.render_size = tmp
return out
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = torch.from_numpy(raw_obs["image"])
obs = {"image": image}
if not self.pixels_only:
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
else:
# TODO:
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([done], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec(
low=0,
high=255,
shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
state_shape = self._env.observation_space["agent_pos"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = BoundedTensorSpec(
low=0,
high=512,
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
self.action_spec = _gym_to_torchrl_spec_transform(
self._env.action_space,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)

View File

@@ -1,378 +0,0 @@
import collections
import cv2
import gym
import numpy as np
import pygame
import pymunk
import pymunk.pygame_util
import shapely.geometry as sg
import skimage.transform as st
from gym import spaces
from pymunk.vec2d import Vec2d
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
def pymunk_to_shapely(body, shapes):
geoms = []
for shape in shapes:
if isinstance(shape, pymunk.shapes.Poly):
verts = [body.local_to_world(v) for v in shape.get_vertices()]
verts += [verts[0]]
geoms.append(sg.Polygon(verts))
else:
raise RuntimeError(f"Unsupported shape type {type(shape)}")
geom = sg.MultiPolygon(geoms)
return geom
class PushTEnv(gym.Env):
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
reward_range = (0.0, 1.0)
def __init__(
self,
legacy=False,
block_cog=None,
damping=None,
render_action=True,
render_size=96,
reset_to_state=None,
):
self._seed = None
self.seed()
self.window_size = ws = 512 # The size of the PyGame window
self.render_size = render_size
self.sim_hz = 100
# Local controller params.
self.k_p, self.k_v = 100, 20 # PD control.z
self.control_hz = self.metadata["video.frames_per_second"]
# legcay set_state for data compatibility
self.legacy = legacy
# agent_pos, block_pos, block_angle
self.observation_space = spaces.Box(
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
shape=(5,),
dtype=np.float64,
)
# positional goal for agent
self.action_space = spaces.Box(
low=np.array([0, 0], dtype=np.float64),
high=np.array([ws, ws], dtype=np.float64),
shape=(2,),
dtype=np.float64,
)
self.block_cog = block_cog
self.damping = damping
self.render_action = render_action
"""
If human-rendering is used, `self.window` will be a reference
to the window that we draw to. `self.clock` will be a clock that is used
to ensure that the environment is rendered at the correct framerate in
human-mode. They will remain `None` until human-mode is used for the
first time.
"""
self.window = None
self.clock = None
self.screen = None
self.space = None
self.teleop = None
self.render_buffer = None
self.latest_action = None
self.reset_to_state = reset_to_state
def reset(self):
seed = self._seed
self._setup()
if self.block_cog is not None:
self.block.center_of_gravity = self.block_cog
if self.damping is not None:
self.space.damping = self.damping
# use legacy RandomState for compatibility
state = self.reset_to_state
if state is None:
rs = np.random.RandomState(seed=seed)
state = np.array(
[
rs.randint(50, 450),
rs.randint(50, 450),
rs.randint(100, 400),
rs.randint(100, 400),
rs.randn() * 2 * np.pi - np.pi,
]
)
self._set_state(state)
observation = self._get_obs()
return observation
def step(self, action):
dt = 1.0 / self.sim_hz
self.n_contact_points = 0
n_steps = self.sim_hz // self.control_hz
if action is not None:
self.latest_action = action
for _ in range(n_steps):
# Step PD control.
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
Vec2d(0, 0) - self.agent.velocity
)
self.agent.velocity += acceleration * dt
# Step physics.
self.space.step(dt)
# compute reward
goal_body = self._get_goal_pose_body(self.goal_pose)
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage = intersection_area / goal_area
reward = np.clip(coverage / self.success_threshold, 0, 1)
done = coverage > self.success_threshold
observation = self._get_obs()
info = self._get_info()
return observation, reward, done, info
def render(self, mode):
return self._render_frame(mode)
def teleop_agent(self):
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
def act(obs):
act = None
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
if self.teleop or (mouse_position - self.agent.position).length < 30:
self.teleop = True
act = mouse_position
return act
return TeleopAgent(act)
def _get_obs(self):
obs = np.array(
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
)
return obs
def _get_goal_pose_body(self, pose):
mass = 1
inertia = pymunk.moment_for_box(mass, (50, 100))
body = pymunk.Body(mass, inertia)
# preserving the legacy assignment order for compatibility
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
body.position = pose[:2].tolist()
body.angle = pose[2]
return body
def _get_info(self):
n_steps = self.sim_hz // self.control_hz
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
info = {
"pos_agent": np.array(self.agent.position),
"vel_agent": np.array(self.agent.velocity),
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
"goal_pose": self.goal_pose,
"n_contacts": n_contact_points_per_step,
}
return info
def _render_frame(self, mode):
if self.window is None and mode == "human":
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode((self.window_size, self.window_size))
if self.clock is None and mode == "human":
self.clock = pygame.time.Clock()
canvas = pygame.Surface((self.window_size, self.window_size))
canvas.fill((255, 255, 255))
self.screen = canvas
draw_options = DrawOptions(canvas)
# Draw goal pose.
goal_body = self._get_goal_pose_body(self.goal_pose)
for shape in self.block.shapes:
goal_points = [
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
for v in shape.get_vertices()
]
goal_points += [goal_points[0]]
pygame.draw.polygon(canvas, self.goal_color, goal_points)
# Draw agent and block.
self.space.debug_draw(draw_options)
if mode == "human":
# The following line copies our drawings from `canvas` to the visible window
self.window.blit(canvas, canvas.get_rect())
pygame.event.pump()
pygame.display.update()
# the clock is already ticked during in step for "human"
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
img = cv2.resize(img, (self.render_size, self.render_size))
if self.render_action and self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
return img
def close(self):
if self.window is not None:
pygame.display.quit()
pygame.quit()
def seed(self, seed=None):
if seed is None:
seed = np.random.randint(0, 25536)
self._seed = seed
self.np_random = np.random.default_rng(seed)
def _handle_collision(self, arbiter, space, data):
self.n_contact_points += len(arbiter.contact_point_set.points)
def _set_state(self, state):
if isinstance(state, np.ndarray):
state = state.tolist()
pos_agent = state[:2]
pos_block = state[2:4]
rot_block = state[4]
self.agent.position = pos_agent
# setting angle rotates with respect to center of mass
# therefore will modify the geometric position
# if not the same as CoM
# therefore should be modified first.
if self.legacy:
# for compatibility with legacy data
self.block.position = pos_block
self.block.angle = rot_block
else:
self.block.angle = rot_block
self.block.position = pos_block
# Run physics to take effect
self.space.step(1.0 / self.sim_hz)
def _set_state_local(self, state_local):
agent_pos_local = state_local[:2]
block_pose_local = state_local[2:]
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
agent_pos_new = tf_img_new(agent_pos_local)
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
self._set_state(new_state)
return new_state
def _setup(self):
self.space = pymunk.Space()
self.space.gravity = 0, 0
self.space.damping = 0
self.teleop = False
self.render_buffer = []
# Add walls.
walls = [
self._add_segment((5, 506), (5, 5), 2),
self._add_segment((5, 5), (506, 5), 2),
self._add_segment((506, 5), (506, 506), 2),
self._add_segment((5, 506), (506, 506), 2),
]
self.space.add(*walls)
# Add agent, block, and goal zone.
self.agent = self.add_circle((256, 400), 15)
self.block = self.add_tee((256, 300), 0)
self.goal_color = pygame.Color("LightGreen")
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
# Add collision handling
self.collision_handeler = self.space.add_collision_handler(0, 0)
self.collision_handeler.post_solve = self._handle_collision
self.n_contact_points = 0
self.max_score = 50 * 100
self.success_threshold = 0.95 # 95% coverage.
def _add_segment(self, a, b, radius):
shape = pymunk.Segment(self.space.static_body, a, b, radius)
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
return shape
def add_circle(self, position, radius):
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
body.position = position
body.friction = 1
shape = pymunk.Circle(body, radius)
shape.color = pygame.Color("RoyalBlue")
self.space.add(body, shape)
return body
def add_box(self, position, height, width):
mass = 1
inertia = pymunk.moment_for_box(mass, (height, width))
body = pymunk.Body(mass, inertia)
body.position = position
shape = pymunk.Poly.create_box(body, (height, width))
shape.color = pygame.Color("LightSlateGray")
self.space.add(body, shape)
return body
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
if mask is None:
mask = pymunk.ShapeFilter.ALL_MASKS()
mass = 1
length = 4
vertices1 = [
(-length * scale / 2, scale),
(length * scale / 2, scale),
(length * scale / 2, 0),
(-length * scale / 2, 0),
]
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
vertices2 = [
(-scale / 2, scale),
(-scale / 2, length * scale),
(scale / 2, length * scale),
(scale / 2, scale),
]
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
body = pymunk.Body(mass, inertia1 + inertia2)
shape1 = pymunk.Poly(body, vertices1)
shape2 = pymunk.Poly(body, vertices2)
shape1.color = pygame.Color(color)
shape2.color = pygame.Color(color)
shape1.filter = pymunk.ShapeFilter(mask=mask)
shape2.filter = pymunk.ShapeFilter(mask=mask)
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
body.position = position
body.angle = angle
body.friction = 1
self.space.add(body, shape1, shape2)
return body

View File

@@ -1,55 +0,0 @@
import cv2
import numpy as np
from gym import spaces
from lerobot.common.envs.pusht.pusht_env import PushTEnv
class PushTImageEnv(PushTEnv):
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
super().__init__(
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
)
ws = self.window_size
self.observation_space = spaces.Dict(
{
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
}
)
self.render_cache = None
def _get_obs(self):
img = super()._render_frame(mode="rgb_array")
agent_pos = np.array(self.agent.position)
img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action
if self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self.render_cache = img
return obs
def render(self, mode):
assert mode == "rgb_array"
if self.render_cache is None:
self._get_obs()
return self.render_cache

View File

@@ -1,244 +0,0 @@
# ----------------------------------------------------------------------------
# pymunk
# Copyright (c) 2007-2016 Victor Blomqvist
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ----------------------------------------------------------------------------
"""This submodule contains helper functions to help with quick prototyping
using pymunk together with pygame.
Intended to help with debugging and prototyping, not for actual production use
in a full application. The methods contained in this module is opinionated
about your coordinate system and not in any way optimized.
"""
__docformat__ = "reStructuredText"
__all__ = [
"DrawOptions",
"get_mouse_pos",
"to_pygame",
"from_pygame",
# "lighten",
"positive_y_is_up",
]
from typing import Sequence, Tuple
import numpy as np
import pygame
import pymunk
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
positive_y_is_up: bool = False
"""Make increasing values of y point upwards.
When True::
y
^
| . (3, 3)
|
| . (2, 2)
|
+------ > x
When False::
+------ > x
|
| . (2, 2)
|
| . (3, 3)
v
y
"""
class DrawOptions(pymunk.SpaceDebugDrawOptions):
def __init__(self, surface: pygame.Surface) -> None:
"""Draw a pymunk.Space on a pygame.Surface object.
Typical usage::
>>> import pymunk
>>> surface = pygame.Surface((10,10))
>>> space = pymunk.Space()
>>> options = pymunk.pygame_util.DrawOptions(surface)
>>> space.debug_draw(options)
You can control the color of a shape by setting shape.color to the color
you want it drawn in::
>>> c = pymunk.Circle(None, 10)
>>> c.color = pygame.Color("pink")
See pygame_util.demo.py for a full example
Since pygame uses a coordinate system where y points down (in contrast
to many other cases), you either have to make the physics simulation
with Pymunk also behave in that way, or flip everything when you draw.
The easiest is probably to just make the simulation behave the same
way as Pygame does. In that way all coordinates used are in the same
orientation and easy to reason about::
>>> space = pymunk.Space()
>>> space.gravity = (0, -1000)
>>> body = pymunk.Body()
>>> body.position = (0, 0) # will be positioned in the top left corner
>>> space.debug_draw(options)
To flip the drawing its possible to set the module property
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
the simulation upside down before drawing::
>>> positive_y_is_up = True
>>> body = pymunk.Body()
>>> body.position = (0, 0)
>>> # Body will be position in bottom left corner
:Parameters:
surface : pygame.Surface
Surface that the objects will be drawn on
"""
self.surface = surface
super().__init__()
def draw_circle(
self,
pos: Vec2d,
angle: float,
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
p = to_pygame(pos, self.surface)
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
# p2 = to_pygame(circle_edge, self.surface)
# line_r = 2 if radius > 20 else 1
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
p1 = to_pygame(a, self.surface)
p2 = to_pygame(b, self.surface)
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
def draw_fat_segment(
self,
a: Tuple[float, float],
b: Tuple[float, float],
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
p1 = to_pygame(a, self.surface)
p2 = to_pygame(b, self.surface)
r = round(max(1, radius * 2))
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
if r > 2:
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
if orthog[0] == 0 and orthog[1] == 0:
return
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
orthog[0] = round(orthog[0] * scale)
orthog[1] = round(orthog[1] * scale)
points = [
(p1[0] - orthog[0], p1[1] - orthog[1]),
(p1[0] + orthog[0], p1[1] + orthog[1]),
(p2[0] + orthog[0], p2[1] + orthog[1]),
(p2[0] - orthog[0], p2[1] - orthog[1]),
]
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
pygame.draw.circle(
self.surface,
fill_color.as_int(),
(round(p1[0]), round(p1[1])),
round(radius),
)
pygame.draw.circle(
self.surface,
fill_color.as_int(),
(round(p2[0]), round(p2[1])),
round(radius),
)
def draw_polygon(
self,
verts: Sequence[Tuple[float, float]],
radius: float,
outline_color: SpaceDebugColor,
fill_color: SpaceDebugColor,
) -> None:
ps = [to_pygame(v, self.surface) for v in verts]
ps += [ps[0]]
radius = 2
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
if radius > 0:
for i in range(len(verts)):
a = verts[i]
b = verts[(i + 1) % len(verts)]
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
p = to_pygame(pos, self.surface)
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
"""Get position of the mouse pointer in pymunk coordinates."""
p = pygame.mouse.get_pos()
return from_pygame(p, surface)
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
"""Convenience method to convert pymunk coordinates to pygame surface
local coordinates.
Note that in case positive_y_is_up is False, this function won't actually do
anything except converting the point to integers.
"""
if positive_y_is_up:
return round(p[0]), surface.get_height() - round(p[1])
else:
return round(p[0]), round(p[1])
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
"""Convenience method to convert pygame surface local coordinates to
pymunk coordinates
"""
return to_pygame(p, surface)
def light_color(color: SpaceDebugColor):
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
return color

View File

@@ -1,236 +0,0 @@
import importlib
from collections import deque
from typing import Optional
import einops
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
MAX_NUM_ACTIONS = 4
_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
class SimxarmEnv(AbstractEnv):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
seed=1337,
device="cpu",
num_prev_obs=0,
num_prev_action=0,
):
super().__init__(
task=task,
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=image_size,
seed=seed,
device=device,
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
def _make_env(self):
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
raise ImportError("Cannot import gym.")
import gym
from simxarm import TASKS
if self.task not in TASKS:
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
self._env = TASKS[self.task]["env"]()
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
image = torch.tensor(image.copy(), dtype=torch.uint8)
obs = {"image": image}
if not self.pixels_only:
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
td = tensordict
if td is None or td.is_empty():
raw_obs = self._env.reset()
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
"observation": TensorDict(obs, batch_size=[]),
"done": torch.tensor([False], dtype=torch.bool),
},
batch_size=[],
)
else:
raise NotImplementedError()
self.call_rendering_hooks()
return td
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
sum_reward = 0
if action.ndim == 1:
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
else:
if self.frame_skip > 1:
raise NotImplementedError()
num_action_steps = action.shape[0]
for i in range(num_action_steps):
raw_obs, reward, done, info = self._env.step(action[i])
sum_reward += reward
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
self.call_rendering_hooks()
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},
batch_size=[],
)
return td
def _make_spec(self):
obs = {}
if self.from_pixels:
image_shape = (3, self.image_size, self.image_size)
if self.num_prev_obs > 0:
image_shape = (self.num_prev_obs + 1, *image_shape)
obs["image"] = BoundedTensorSpec(
low=0,
high=255,
shape=image_shape,
dtype=torch.uint8,
device=self.device,
)
if not self.pixels_only:
state_shape = (len(self._env.robot_state),)
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
else:
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
state_shape = self._env.observation_space["observation"].shape
if self.num_prev_obs > 0:
state_shape = (self.num_prev_obs + 1, *state_shape)
obs["state"] = UnboundedContinuousTensorSpec(
# TODO:
shape=state_shape,
dtype=torch.float32,
device=self.device,
)
self.observation_spec = CompositeSpec({"observation": obs})
self.action_spec = _gym_to_torchrl_spec_transform(
self._action_space,
device=self.device,
)
self.reward_spec = UnboundedContinuousTensorSpec(
shape=(1,),
dtype=torch.float32,
device=self.device,
)
self.done_spec = CompositeSpec(
{
"done": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
"success": DiscreteTensorSpec(
2,
shape=(1,),
dtype=torch.bool,
device=self.device,
),
}
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)

View File

@@ -1,119 +0,0 @@
from typing import Sequence
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform, Transform
class Prod(ObservationTransform):
invertible = True
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td):
for key in self.in_keys:
if td.get(key, None) is None:
continue
self.original_dtypes[key] = td[key].dtype
td[key] = td[key].type(torch.float32) * self.prod
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
for key in self.in_keys:
if td.get(key, None) is None:
continue
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
return td
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
if obs_spec.get(key, None) is None:
continue
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
obs_spec[key].dtype = torch.float32
return obs_spec
class NormalizeTransform(Transform):
invertible = True
def __init__(
self,
stats: TensorDictBase,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
mode="mean_std",
):
if out_keys is None:
out_keys = in_keys
if in_keys_inv is None:
in_keys_inv = out_keys
if out_keys_inv is None:
out_keys_inv = in_keys
super().__init__(
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
)
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
td[outkey] = (td[inkey] - min) / (max - min)
# normalize to [-1, 1]
td[outkey] = td[outkey] * 2 - 1
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = td[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
td[outkey] = (td[inkey] + 1) / 2
td[outkey] = td[outkey] * (max - min) + min
return td

View File

@@ -0,0 +1,52 @@
import einops
import torch
from lerobot.common.transforms import apply_inverse_transform
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy
obs = {}
if isinstance(observation["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
imgs = {"observation.image": observation["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
return obs
def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
return action

View File

@@ -30,6 +30,7 @@ class Logger:
self._model_dir = self._log_dir / "models"
self._buffer_dir = self._log_dir / "buffers"
self._save_model = cfg.save_model
self._disable_wandb_artifact = cfg.wandb.disable_artifact
self._save_buffer = cfg.save_buffer
self._group = cfg_to_group(cfg)
self._seed = cfg.seed
@@ -71,9 +72,10 @@ class Logger:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
if self._wandb:
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(
self._group + "-" + str(self._seed) + "-" + str(identifier),
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
type="model",
)
artifact.add_file(fp)

View File

@@ -1,115 +0,0 @@
from typing import List
import torch
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from .position_encoding import build_position_encoding
from .utils import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(),
norm_layer=FrozenBatchNorm2d,
) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for _, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(args):
position_embedding = build_position_encoding(args)
train_backbone = args.lr_backbone > 0
return_interm_layers = args.masks
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model

View File

@@ -0,0 +1,123 @@
from dataclasses import dataclass, field
@dataclass
class ActionChunkingTransformerConfig:
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `camera_names`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
convolution.
pre_norm: Whether to use "pre-norm" in the transformer blocks.
d_model: The transformer blocks' main hidden dimension.
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
layers.
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
n_encoder_layers: The number of transformer layers to use for the transformer encoder.
n_decoder_layers: The number of transformer layers to use for the transformer decoder.
use_vae: Whether to use a variational objective during training. This introduces another transformer
which is used as the VAE's encoder (not to be confused with the transformer encoder - see
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
environment step.
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
# Environment.
state_dim: int = 14
action_dim: int = 14
# Inputs / output structure.
n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Architecture.
# Vision backbone.
vision_backbone: str = "resnet18"
use_pretrained_backbone: bool = True
replace_final_stride_with_dilation: int = False
# Transformer layers.
pre_norm: bool = False
d_model: int = 512
n_heads: int = 8
dim_feedforward: int = 3200
feedforward_activation: str = "relu"
n_encoder_layers: int = 4
n_decoder_layers: int = 1
# VAE.
use_vae: bool = True
latent_dim: int = 32
n_vae_encoder_layers: int = 4
# Inference.
use_temporal_aggregation: bool = False
# Training and loss computation.
dropout: float = 0.1
kl_weight: float = 10.0
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 8
lr: float = 1e-5
lr_backbone: float = 1e-5
weight_decay: float = 1e-4
grad_clip_norm: float = 10
utd: int = 1
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.use_temporal_aggregation:
raise NotImplementedError("Temporal aggregation is not yet implemented.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")

View File

@@ -1,212 +0,0 @@
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from .backbone import build_backbone
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vae = vae
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
# TODO(rcadene): understand what is env_state, and why it needs to be 7
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
self.latent_proj = nn.Linear(
hidden_dim, self.latent_dim * 2
) # project hidden state to latent std, var
self.register_buffer(
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
) # [CLS], qpos, a_seq
# decoder extra parameters
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
### Obtain latent z from action sequence
if self.vae and is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat(
[cls_embed, qpos_embed, action_embed], axis=1
) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
mu = latent_info[:, : self.latent_dim]
logvar = latent_info[:, self.latent_dim :]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = None
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, _ in enumerate(self.camera_names):
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar]
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for _ in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(args):
d_model = args.hidden_dim # 256
dropout = args.dropout # 0.1
nhead = args.nheads # 8
dim_feedforward = args.dim_feedforward # 2048
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
normalize_before = args.pre_norm # False
activation = "relu"
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build(args):
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(args)
backbones.append(backbone)
transformer = build_transformer(args)
encoder = build_encoder(args)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim=args.state_dim,
action_dim=args.action_dim,
num_queries=args.num_queries,
camera_names=args.camera_names,
vae=args.vae,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
return model

View File

@@ -0,0 +1,599 @@
"""Action Chunking Transformer Policy
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
"""
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
import torchvision.transforms as transforms
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
class ActionChunkingTransformerPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
model that encodes the target data (a sequence of actions), and the condition (the robot
joint-space).
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
have an option to train this model without the variational objective (in which case we drop the
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
Transformer
Used alone for inference
(acts as VAE decoder
during training)
┌───────────────────────┐
│ Outputs │
│ ▲ │
│ ┌─────►┌───────┐ │
┌──────┐ │ │ │Transf.│ │
│ │ │ ├─────►│decoder│ │
┌────┴────┐ │ │ │ │ │ │
│ │ │ │ ┌───┴───┬─►│ │ │
│ VAE │ │ │ │ │ └───────┘ │
│ encoder │ │ │ │Transf.│ │
│ │ │ │ │encoder│ │
└───▲─────┘ │ │ │ │ │
│ │ │ └───▲───┘ │
│ │ │ │ │
inputs └─────┼─────┘ │
│ │
└───────────────────────┘
"""
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
self.cfg = cfg
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae:
self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.latent_dim = cfg.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
self.register_buffer(
"vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
)
# Backbone for image feature extraction.
self.image_normalizer = transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
pretrained=cfg.use_pretrained_backbone,
norm_layer=FrozenBatchNorm2d,
)
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
# map).
# Note: The forward method of this returns a dict: {"feature_map": output}.
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = _TransformerEncoder(cfg)
self.decoder = _TransformerDecoder(cfg)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(cfg.state_dim, cfg.d_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
)
# Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
self._reset_parameters()
self._create_optimizer()
def _create_optimizer(self):
optimizer_params_dicts = [
{
"params": [
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
],
"lr": self.cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def reset(self):
"""This should be called whenever the environment is reset."""
if self.cfg.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if len(self._action_queue) == 0:
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
# has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss}
if self.cfg.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
else:
loss_dict["loss"] = l1_loss
return loss_dict
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
loss_dict = self.forward(batch)
loss = loss_dict["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optimizer.step()
self.optimizer.zero_grad()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"update_s": time.time() - start_time,
}
return info
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
This function expects `batch` to have (at least):
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images.{name}": (B, C, H, W) tensor of images.
}
"""
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
provided_cameras = {k.rsplit(".", 1)[-1] for k in batch if k.startswith("observation.images.")}
if len(missing := set(self.cfg.camera_names).difference(provided_cameras)) > 0:
raise ValueError(
f"The following camera images are missing from the provided batch: {missing}. Check the "
"configuration parameter: `camera_names`."
)
# Stack images in the order dictated by the camera names.
batch["observation.images"] = torch.stack(
[batch[f"observation.images.{name}"] for name in self.cfg.camera_names],
dim=-4,
)
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
`batch` should have the following structure:
{
"observation.state": (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
(B, chunk_size, action_dim) batch of action sequences
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
latent dimension.
"""
if self.cfg.use_vae and self.training:
assert (
"action" in batch
), "actions must be provided when using the variational objective in training mode."
self._stack_images(batch)
batch_size = batch["observation.state"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.cfg.use_vae and "action" in batch:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
1
) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
# Prepare all other transformer encoder inputs.
# Camera observation features and positional embeddings.
all_cam_features = []
all_cam_pos_embeds = []
images = self.image_normalizer(batch["observation.images"])
for cam_index in range(len(self.cfg.camera_names)):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
encoder_in = torch.cat(all_cam_features, axis=3)
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
latent_embed = self.encoder_latent_input_proj(latent_sample)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
encoder_in.flatten(2).permute(2, 0, 1),
]
)
pos_embed = torch.cat(
[
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
cam_pos_embed.flatten(2).permute(2, 0, 1),
],
axis=0,
)
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
decoder_in = torch.zeros(
(self.cfg.chunk_size, batch_size, self.cfg.d_model),
dtype=pos_embed.dtype,
device=pos_embed.device,
)
decoder_out = self.decoder(
decoder_in,
encoder_out,
encoder_pos_embed=pos_embed,
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
)
# Move back to (B, S, C).
decoder_out = decoder_out.transpose(0, 1)
actions = self.action_head(decoder_out)
return actions, (mu, log_sigma_x2)
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
class _TransformerEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed)
x = self.norm(x)
return x
class _TransformerEncoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
self.norm1 = nn.LayerNorm(cfg.d_model)
self.norm2 = nn.LayerNorm(cfg.d_model)
self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout2(x)
if not self.pre_norm:
x = self.norm2(x)
return x
class _TransformerDecoder(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model)
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
for layer in self.layers:
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)
return x
class _TransformerDecoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
# Feed forward layers.
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
self.dropout = nn.Dropout(cfg.dropout)
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
self.norm1 = nn.LayerNorm(cfg.d_model)
self.norm2 = nn.LayerNorm(cfg.d_model)
self.norm3 = nn.LayerNorm(cfg.d_model)
self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout)
self.dropout3 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
return tensor if pos_embed is None else tensor + pos_embed
def forward(
self,
x: Tensor,
encoder_out: Tensor,
decoder_pos_embed: Tensor | None = None,
encoder_pos_embed: Tensor | None = None,
) -> Tensor:
"""
Args:
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
x = self.norm2(x)
else:
x = self.norm1(x)
skip = x
x = self.multihead_attn(
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
value=encoder_out,
)[0] # select just the output, not the attention weights
x = skip + self.dropout2(x)
if self.pre_norm:
skip = x
x = self.norm3(x)
else:
x = self.norm2(x)
skip = x
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = skip + self.dropout3(x)
if not self.pre_norm:
x = self.norm3(x)
return x
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
num_positions: Number of token positions required.
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
"""
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.from_numpy(sinusoid_table).float()
class _SinusoidalPositionEmbedding2D(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
for the vertical direction, and 1/W for the horizontal direction.
"""
def __init__(self, dimension: int):
"""
Args:
dimension: The desired dimension of the embeddings.
"""
super().__init__()
self.dimension = dimension
self._two_pi = 2 * math.pi
self._eps = 1e-6
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
self._temperature = 10000
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
Returns:
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
"""
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
y_range = not_mask.cumsum(1, dtype=torch.float32)
x_range = not_mask.cumsum(2, dtype=torch.float32)
# "Normalize" the position index such that it ranges in [0, 2π].
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
# are non-zero by construction. This is an artifact of the original code.
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
inverse_frequency = self._temperature ** (
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
)
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
return pos_embed
def _get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string."""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

View File

@@ -1,217 +0,0 @@
import logging
import time
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision.transforms as transforms
from lerobot.common.policies.act.detr_vae import build
def build_act_model_and_optimizer(cfg):
model = build(cfg)
param_dicts = [
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
{
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
"lr": cfg.lr_backbone,
},
]
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
return model, optimizer
def kl_divergence(mu, logvar):
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
class ActionChunkingTransformerPolicy(nn.Module):
def __init__(self, cfg, device, n_action_steps=1):
super().__init__()
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")
self.to(self.device)
def update(self, replay_buffer, step):
del step
start_time = time.time()
self.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon)
image = batch["observation", "image", "top"]
image = image[:, 0] # first observation t=0
# batch, num_cam, channel, height, width
image = image.unsqueeze(1)
assert image.ndim == 5
image = image.float()
state = batch["observation", "state"]
state = state[:, 0] # first observation t=0
# batch, qpos_dim
assert state.ndim == 2
action = batch["action"]
# batch, seq, action_dim
assert action.ndim == 3
assert action.shape[1] == horizon
if self.cfg.n_obs_steps > 1:
raise NotImplementedError()
# # keep first n observations of the slice corresponding to t=[-1,0]
# image = image[:, : self.cfg.n_obs_steps]
# state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
loss = self.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
# self.lr_scheduler.step()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
# "lr": self.lr_scheduler.get_last_lr()[0],
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
def compute_loss(self, batch):
loss_dict = self._forward(
qpos=batch["obs"]["agent_pos"],
image=batch["obs"]["image"],
actions=batch["action"],
)
loss = loss_dict["loss"]
return loss
@torch.no_grad()
def forward(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
self.eval()
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
raise NotImplementedError()
# all_time_actions[[t], t:t+num_queries] = action
# actions_for_curr_step = all_time_actions[:, t]
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
# actions_for_curr_step = actions_for_curr_step[actions_populated]
# k = 0.01
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
# exp_weights = exp_weights / exp_weights.sum()
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image = normalize(image)
is_training = actions is not None
if is_training: # training time
actions = actions[:, : self.model.num_queries]
if is_pad is not None:
is_pad = is_pad[:, : self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict = {}
loss_dict["l1"] = l1
if self.cfg.vae:
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict["kl"] = total_kld[0]
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
else:
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
return action

View File

@@ -1,101 +0,0 @@
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from .utils import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = (
torch.cat(
[
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
],
dim=-1,
)
.permute(2, 0, 1)
.unsqueeze(0)
.repeat(x.shape[0], 1, 1, 1)
)
return pos
def build_position_encoding(args):
n_steps = args.hidden_dim // 2
if args.position_embedding in ("v2", "sine"):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
elif args.position_embedding in ("v3", "learned"):
position_embedding = PositionEmbeddingLearned(n_steps)
else:
raise ValueError(f"not supported {args.position_embedding}")
return position_embedding

View File

@@ -1,370 +0,0 @@
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
src,
mask,
query_embed,
pos_embed,
latent_input=None,
proprio_input=None,
additional_pos_embed=None,
):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
)
def _get_clones(module, n):
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
def build_transformer(args):
return Transformer(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
dim_feedforward=args.dim_feedforward,
num_encoder_layers=args.enc_layers,
num_decoder_layers=args.dec_layers,
normalize_before=args.pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")

View File

@@ -1,477 +0,0 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import datetime
import os
import pickle
import subprocess
import time
from collections import defaultdict, deque
from typing import List, Optional
import torch
import torch.distributed as dist
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from packaging import version
from torch import Tensor
if version.parse(torchvision.__version__) < version.parse("0.7"):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list, strict=False):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
return reduced_dict
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable):
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / mega_b,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
sha = "N/A"
diff = "clean"
branch = "N/A"
try:
sha = _run(["git", "rev-parse", "HEAD"])
subprocess.check_output(["git", "diff"], cwd=cwd)
diff = _run(["git", "diff-index", "HEAD"])
diff = "has uncommited changes" if diff else "clean"
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch, strict=False))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor:
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
torch.int64
)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse("0.7"):
if input.numel() > 0:
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -0,0 +1,135 @@
from dataclasses import dataclass
@dataclass
class DiffusionConfig:
"""Configuration class for Diffusion Policy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `image_size`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
image_size: (H, W) size of the input images.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
within the image size. If None, no cropping is done.
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
mode).
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
You may provide a variable number of dimensions, therefore also controlling the degree of
downsampling.
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
network. This is the output dimension of that network, i.e., the embedding dimension.
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
beta_end: Beta value for the last forward-diffusion step.
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
"epsilon" has been shown to work better in many deep neural network settings.
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
denoising step at inference time. WARNING: you will need to make sure your action-space is
normalized to fit within this range.
clip_sample_range: The magnitude of the clipping range as described above.
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
"""
# Environment.
# Inherit these from the environment config.
state_dim: int = 2
action_dim: int = 2
image_size: tuple[int, int] = (96, 96)
# Inputs / output structure.
n_obs_steps: int = 2
horizon: int = 16
n_action_steps: int = 8
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
# Architecture / modeling.
# Vision backbone.
vision_backbone: str = "resnet18"
crop_shape: tuple[int, int] | None = (84, 84)
crop_is_random: bool = True
use_pretrained_backbone: bool = False
use_group_norm: bool = True
spatial_softmax_num_keypoints: int = 32
# Unet.
down_dims: tuple[int, ...] = (512, 1024, 2048)
kernel_size: int = 5
n_groups: int = 8
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
beta_end: float = 0.02
prediction_type: str = "epsilon"
clip_sample: bool = True
clip_sample_range: float = 1.0
# Inference
num_inference_steps: int | None = None
# ---
# TODO(alexander-soare): Remove these from the policy config.
batch_size: int = 64
grad_clip_norm: int = 10
lr: float = 1.0e-4
lr_scheduler: str = "cosine"
lr_warmup_steps: int = 500
adam_betas: tuple[float, float] = (0.95, 0.999)
adam_eps: float = 1.0e-8
adam_weight_decay: float = 1.0e-6
utd: int = 1
use_ema: bool = True
ema_update_after_step: int = 0
ema_min_alpha: float = 0.0
ema_max_alpha: float = 0.9999
ema_inv_gamma: float = 1.0
ema_power: float = 0.75
def __post_init__(self):
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]:
raise ValueError(
f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
f"{self.image_size} for `image_size`."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)

View File

@@ -1,268 +0,0 @@
from typing import Dict
import torch
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from einops import reduce
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class BaseImagePolicy(ModuleAttrMixin):
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
obs_dict:
str: B,To,*
return: B,Ta,Da
"""
raise NotImplementedError()
# reset state for stateful policies
def reset(self):
pass
# ========== training ===========
# no standard training interface except setting normalizer
def set_normalizer(self, normalizer: LinearNormalizer):
raise NotImplementedError()
class DiffusionUnetImagePolicy(BaseImagePolicy):
def __init__(
self,
shape_meta: dict,
noise_scheduler: DDPMScheduler,
obs_encoder: MultiImageObsEncoder,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
):
super().__init__()
# parse shapes
action_shape = shape_meta["action"]["shape"]
assert len(action_shape) == 1
action_dim = action_shape[0]
# get feature dim
obs_feature_dim = obs_encoder.output_shape()[0]
# create diffusion model
input_dim = action_dim + obs_feature_dim
global_cond_dim = None
if obs_as_global_cond:
input_dim = action_dim
global_cond_dim = obs_feature_dim * n_obs_steps
model = ConditionalUnet1D(
input_dim=input_dim,
local_cond_dim=None,
global_cond_dim=global_cond_dim,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
)
self.obs_encoder = obs_encoder
self.model = model
self.noise_scheduler = noise_scheduler
self.mask_generator = LowdimMaskGenerator(
action_dim=action_dim,
obs_dim=0 if obs_as_global_cond else obs_feature_dim,
max_n_obs_steps=n_obs_steps,
fix_obs_steps=True,
action_visible=False,
)
self.horizon = horizon
self.obs_feature_dim = obs_feature_dim
self.action_dim = action_dim
self.n_action_steps = n_action_steps
self.n_obs_steps = n_obs_steps
self.obs_as_global_cond = obs_as_global_cond
self.kwargs = kwargs
if num_inference_steps is None:
num_inference_steps = noise_scheduler.config.num_train_timesteps
self.num_inference_steps = num_inference_steps
# ========= inference ============
def conditional_sample(
self,
condition_data,
condition_mask,
local_cond=None,
global_cond=None,
generator=None,
# keyword arguments to scheduler.step
**kwargs,
):
model = self.model
scheduler = self.noise_scheduler
trajectory = torch.randn(
size=condition_data.shape,
dtype=condition_data.dtype,
device=condition_data.device,
generator=generator,
)
# set step values
scheduler.set_timesteps(self.num_inference_steps)
for t in scheduler.timesteps:
# 1. apply conditioning
trajectory[condition_mask] = condition_data[condition_mask]
# 2. predict model output
model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
# 3. compute previous image: x_t -> x_t-1
trajectory = scheduler.step(
model_output,
t,
trajectory,
generator=generator,
# **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
).prev_sample
# finally make sure conditioning is enforced
trajectory[condition_mask] = condition_data[condition_mask]
return trajectory
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
obs_dict: must include "obs" key
result: must include "action" key
"""
assert "past_action" not in obs_dict # not implemented yet
nobs = obs_dict
value = next(iter(nobs.values()))
bsize, n_obs_steps = value.shape[:2]
horizon = self.horizon
action_dim = self.action_dim
obs_dim = self.obs_feature_dim
assert self.n_obs_steps == n_obs_steps
# build input
device = self.device
dtype = self.dtype
# handle different ways of passing observation
local_cond = None
global_cond = None
if self.obs_as_global_cond:
# condition through global feature
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, Do
global_cond = nobs_features.reshape(bsize, -1)
# empty data for action
cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
else:
# condition through impainting
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, T, Do
nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
cond_data[:, :n_obs_steps, action_dim:] = nobs_features
cond_mask[:, :n_obs_steps, action_dim:] = True
# run sampling
nsample = self.conditional_sample(
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
)
action_pred = nsample[..., :action_dim]
# get action
start = n_obs_steps - 1
end = start + self.n_action_steps
action = action_pred[:, start:end]
result = {"action": action, "action_pred": action_pred}
return result
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = batch["obs"]
nactions = batch["action"]
batch_size = nactions.shape[0]
horizon = nactions.shape[1]
# handle different ways of passing observation
local_cond = None
global_cond = None
trajectory = nactions
cond_data = trajectory
if self.obs_as_global_cond:
# reshape B, T, ... to B*T
this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, Do
global_cond = nobs_features.reshape(batch_size, -1)
else:
# reshape B, T, ... to B*T
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
nobs_features = self.obs_encoder(this_nobs)
# reshape back to B, T, Do
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
cond_data = torch.cat([nactions, nobs_features], dim=-1)
trajectory = cond_data.detach()
# generate impainting mask
condition_mask = self.mask_generator(trajectory.shape)
# Sample noise that we'll add to the images
noise = torch.randn(trajectory.shape, device=trajectory.device)
bsz = trajectory.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
).long()
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
# compute loss mask
loss_mask = ~condition_mask
# apply conditioning
noisy_trajectory[condition_mask] = cond_data[condition_mask]
# Predict the noise residual
pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
pred_type = self.noise_scheduler.config.prediction_type
if pred_type == "epsilon":
target = noise
elif pred_type == "sample":
target = trajectory
else:
raise ValueError(f"Unsupported prediction type {pred_type}")
loss = F.mse_loss(pred, target, reduction="none")
loss = loss * loss_mask.type(loss.dtype)
loss = reduce(loss, "b ... -> b (...)", "mean")
loss = loss.mean()
return loss

View File

@@ -1,286 +0,0 @@
import logging
from typing import Union
import einops
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
logger = logging.getLogger(__name__)
class ConditionalResidualBlock1D(nn.Module):
def __init__(
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
]
)
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale:
cond_channels = out_channels * 2
self.cond_predict_scale = cond_predict_scale
self.out_channels = out_channels
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
Rearrange("batch t -> batch t 1"),
)
# make sure dimensions compatible
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x, cond):
"""
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
"""
out = self.blocks[0](x)
embed = self.cond_encoder(cond)
if self.cond_predict_scale:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
else:
out = out + embed
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
def __init__(
self,
input_dim,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=None,
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed
if global_cond_dim is not None:
cond_dim += global_cond_dim
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
local_cond_encoder = None
if local_cond_dim is not None:
_, dim_out = in_out[0]
dim_in = local_cond_dim
local_cond_encoder = nn.ModuleList(
[
# down encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
# up encoder
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
)
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList(
[
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
]
)
down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
down_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
Downsample1d(dim_out) if not is_last else nn.Identity(),
]
)
)
up_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
up_modules.append(
nn.ModuleList(
[
ConditionalResidualBlock1D(
dim_out * 2,
dim_in,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
),
Upsample1d(dim_in) if not is_last else nn.Identity(),
]
)
)
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
local_cond=None,
global_cond=None,
**kwargs,
):
"""
x: (B,T,input_dim)
timestep: (B,) or int, diffusion step
local_cond: (B,T,local_cond_dim)
global_cond: (B,global_cond_dim)
output: (B,T,input_dim)
"""
sample = einops.rearrange(sample, "b h t -> b t h")
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
if global_cond is not None:
global_feature = torch.cat([global_feature, global_cond], axis=-1)
# encode local features
h_local = []
if local_cond is not None:
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
resnet, resnet2 = self.local_cond_encoder
x = resnet(local_cond, global_feature)
h_local.append(x)
x = resnet2(local_cond, global_feature)
h_local.append(x)
x = sample
h = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and len(h_local) > 0:
x = x + h_local[0]
x = resnet2(x, global_feature)
h.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
return x

View File

@@ -1,47 +0,0 @@
import torch.nn as nn
# from einops.layers.torch import Rearrange
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
# def test():
# cb = Conv1dBlock(256, 128, kernel_size=3)
# x = torch.zeros((1,256,16))
# o = cb(x)

View File

@@ -1,294 +0,0 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import lerobot.common.policies.diffusion.model.tensor_utils as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
# the number of crops are reshaped into the batch dimension, increasing the batch
# size from B to B * N
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape # noqa: N806
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = inputs.shape[0] // self.num_crops
out = tu.reshape_dimensions(
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
)
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = "{}".format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width, self.num_crops
)
return msg
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
)
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (n): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None,) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds

View File

@@ -1,41 +0,0 @@
import torch
import torch.nn as nn
class DictOfTensorMixin(nn.Module):
def __init__(self, params_dict=None):
super().__init__()
if params_dict is None:
params_dict = nn.ParameterDict()
self.params_dict = params_dict
@property
def device(self):
return next(iter(self.parameters())).device
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
def dfs_add(dest, keys, value: torch.Tensor):
if len(keys) == 1:
dest[keys[0]] = value
return
if keys[0] not in dest:
dest[keys[0]] = nn.ParameterDict()
dfs_add(dest[keys[0]], keys[1:], value)
def load_dict(state_dict, prefix):
out_dict = nn.ParameterDict()
for key, value in state_dict.items():
value: torch.Tensor
if key.startswith(prefix):
param_keys = key[len(prefix) :].split(".")[1:]
# if len(param_keys) == 0:
# import pdb; pdb.set_trace()
dfs_add(out_dict, param_keys, value.clone())
return out_dict
self.params_dict = load_dict(state_dict, prefix + "params_dict")
self.params_dict.requires_grad_(False)
return

View File

@@ -1,84 +0,0 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.decay = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
# old_all_dataptrs = set()
# for param in new_model.parameters():
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# old_all_dataptrs.add(data_ptr)
# all_dataptrs = set()
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
):
# iterative over immediate parameters only.
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
# data_ptr = param.data_ptr()
# if data_ptr != 0:
# all_dataptrs.add(data_ptr)
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
elif not param.requires_grad:
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
# verify that iterating over module and then parameters is identical to parameters recursively.
# assert old_all_dataptrs == all_dataptrs
self.optimization_step += 1

View File

@@ -1,46 +0,0 @@
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
def get_scheduler(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs,
):
"""
Added kwargs vs diffuser's original implementation
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
)

View File

@@ -1,65 +0,0 @@
import torch
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
class LowdimMaskGenerator(ModuleAttrMixin):
def __init__(
self,
action_dim,
obs_dim,
# obs mask setup
max_n_obs_steps=2,
fix_obs_steps=True,
# action mask
action_visible=False,
):
super().__init__()
self.action_dim = action_dim
self.obs_dim = obs_dim
self.max_n_obs_steps = max_n_obs_steps
self.fix_obs_steps = fix_obs_steps
self.action_visible = action_visible
@torch.no_grad()
def forward(self, shape, seed=None):
device = self.device
B, T, D = shape # noqa: N806
assert (self.action_dim + self.obs_dim) == D
# create all tensors on this device
rng = torch.Generator(device=device)
if seed is not None:
rng = rng.manual_seed(seed)
# generate dim mask
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
is_action_dim = dim_mask.clone()
is_action_dim[..., : self.action_dim] = True
is_obs_dim = ~is_action_dim
# generate obs mask
if self.fix_obs_steps:
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
else:
obs_steps = torch.randint(
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
)
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
obs_mask = obs_mask & is_obs_dim
# generate action mask
if self.action_visible:
action_steps = torch.maximum(
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
)
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
action_mask = action_mask & is_action_dim
mask = obs_mask
if self.action_visible:
mask = mask | action_mask
return mask

View File

@@ -1,15 +0,0 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@@ -1,189 +0,0 @@
import copy
from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
import torchvision
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
shape_meta: dict,
rgb_model: Union[nn.Module, Dict[str, nn.Module]],
resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
random_crop: bool = True,
# replace BatchNorm with GroupNorm
use_group_norm: bool = False,
# use single rgb model for all rgb inputs
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
):
"""
Assumes rgb input: B,C,H,W
Assumes low_dim input: B,D
"""
super().__init__()
rgb_keys = []
low_dim_keys = []
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_shape_map = {}
# handle sharing vision backbone
if share_rgb_model:
assert isinstance(rgb_model, nn.Module)
key_model_map["rgb"] = rgb_model
obs_shape_meta = shape_meta["obs"]
for key, attr in obs_shape_meta.items():
shape = tuple(attr["shape"])
type = attr.get("type", "low_dim")
key_shape_map[key] = shape
if type == "rgb":
rgb_keys.append(key)
# configure model for this key
this_model = None
if not share_rgb_model:
if isinstance(rgb_model, dict):
# have provided model for each key
this_model = rgb_model[key]
else:
assert isinstance(rgb_model, nn.Module)
# have a copy of the rgb model
this_model = copy.deepcopy(rgb_model)
if this_model is not None:
if use_group_norm:
this_model = replace_submodules(
root_module=this_model,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16, num_channels=x.num_features
),
)
key_model_map[key] = this_model
# configure resize
input_shape = shape
this_resizer = nn.Identity()
if resize_shape is not None:
if isinstance(resize_shape, dict):
h, w = resize_shape[key]
else:
h, w = resize_shape
this_resizer = torchvision.transforms.Resize(size=(h, w))
input_shape = (shape[0], h, w)
# configure randomizer
this_randomizer = nn.Identity()
if crop_shape is not None:
if isinstance(crop_shape, dict):
h, w = crop_shape[key]
else:
h, w = crop_shape
if random_crop:
this_randomizer = CropRandomizer(
input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
)
else:
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
key_transform_map[key] = this_transform
elif type == "low_dim":
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
def forward(self, obs_dict):
batch_size = None
features = []
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
imgs = []
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
imgs.append(img)
# (N*B,C,H,W)
imgs = torch.cat(imgs, dim=0)
# (N*B,D)
feature = self.key_model_map["rgb"](imgs)
# (N,B,D)
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
# (B,N,D)
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
features.append(feature)
else:
# run each rgb obs to independent models
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
features.append(feature)
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# concatenate all features
result = torch.cat(features, dim=-1)
return result
@torch.no_grad()
def output_shape(self):
example_obs_dict = {}
obs_shape_meta = self.shape_meta["obs"]
batch_size = 1
for key, attr in obs_shape_meta.items():
shape = tuple(attr["shape"])
this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
example_obs_dict[key] = this_obs
example_output = self.forward(example_obs_dict)
output_shape = example_output.shape[1:]
return output_shape

View File

@@ -1,358 +0,0 @@
from typing import Dict, Union
import numpy as np
import torch
import torch.nn as nn
import zarr
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
class LinearNormalizer(DictOfTensorMixin):
avaliable_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
if isinstance(data, dict):
for key, value in data.items():
self.params_dict[key] = _fit(
value,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
else:
self.params_dict["_default"] = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def __getitem__(self, key: str):
return SingleFieldLinearNormalizer(self.params_dict[key])
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
self.params_dict[key] = value.params_dict
def _normalize_impl(self, x, forward=True):
if isinstance(x, dict):
result = {}
for key, value in x.items():
params = self.params_dict[key]
result[key] = _normalize(value, params, forward=forward)
return result
else:
if "_default" not in self.params_dict:
raise RuntimeError("Not initialized")
params = self.params_dict["_default"]
return _normalize(x, params, forward=forward)
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=True)
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
return self._normalize_impl(x, forward=False)
def get_input_stats(self) -> Dict:
if len(self.params_dict) == 0:
raise RuntimeError("Not initialized")
if len(self.params_dict) == 1 and "_default" in self.params_dict:
return self.params_dict["_default"]["input_stats"]
result = {}
for key, value in self.params_dict.items():
if key != "_default":
result[key] = value["input_stats"]
return result
def get_output_stats(self, key="_default"):
input_stats = self.get_input_stats()
if "min" in input_stats:
# no dict
return dict_apply(input_stats, self.normalize)
result = {}
for key, group in input_stats.items():
this_dict = {}
for name, value in group.items():
this_dict[name] = self.normalize({key: value})[key]
result[key] = this_dict
return result
class SingleFieldLinearNormalizer(DictOfTensorMixin):
avaliable_modes = ["limits", "gaussian"]
@torch.no_grad()
def fit(
self,
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
self.params_dict = _fit(
data,
last_n_dims=last_n_dims,
dtype=dtype,
mode=mode,
output_max=output_max,
output_min=output_min,
range_eps=range_eps,
fit_offset=fit_offset,
)
@classmethod
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
obj = cls()
obj.fit(data, **kwargs)
return obj
@classmethod
def create_manual(
cls,
scale: Union[torch.Tensor, np.ndarray],
offset: Union[torch.Tensor, np.ndarray],
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
):
def to_tensor(x):
if not isinstance(x, torch.Tensor):
x = torch.from_numpy(x)
x = x.flatten()
return x
# check
for x in [offset] + list(input_stats_dict.values()):
assert x.shape == scale.shape
assert x.dtype == scale.dtype
params_dict = nn.ParameterDict(
{
"scale": to_tensor(scale),
"offset": to_tensor(offset),
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
}
)
return cls(params_dict)
@classmethod
def create_identity(cls, dtype=torch.float32):
scale = torch.tensor([1], dtype=dtype)
offset = torch.tensor([0], dtype=dtype)
input_stats_dict = {
"min": torch.tensor([-1], dtype=dtype),
"max": torch.tensor([1], dtype=dtype),
"mean": torch.tensor([0], dtype=dtype),
"std": torch.tensor([1], dtype=dtype),
}
return cls.create_manual(scale, offset, input_stats_dict)
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=True)
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return _normalize(x, self.params_dict, forward=False)
def get_input_stats(self):
return self.params_dict["input_stats"]
def get_output_stats(self):
return dict_apply(self.params_dict["input_stats"], self.normalize)
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
return self.normalize(x)
def _fit(
data: Union[torch.Tensor, np.ndarray, zarr.Array],
last_n_dims=1,
dtype=torch.float32,
mode="limits",
output_max=1.0,
output_min=-1.0,
range_eps=1e-4,
fit_offset=True,
):
assert mode in ["limits", "gaussian"]
assert last_n_dims >= 0
assert output_max > output_min
# convert data to torch and type
if isinstance(data, zarr.Array):
data = data[:]
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if dtype is not None:
data = data.type(dtype)
# convert shape
dim = 1
if last_n_dims > 0:
dim = np.prod(data.shape[-last_n_dims:])
data = data.reshape(-1, dim)
# compute input stats min max mean std
input_min, _ = data.min(axis=0)
input_max, _ = data.max(axis=0)
input_mean = data.mean(axis=0)
input_std = data.std(axis=0)
# compute scale and offset
if mode == "limits":
if fit_offset:
# unit scale
input_range = input_max - input_min
ignore_dim = input_range < range_eps
input_range[ignore_dim] = output_max - output_min
scale = (output_max - output_min) / input_range
offset = output_min - scale * input_min
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
# ignore dims scaled to mean of output max and min
else:
# use this when data is pre-zero-centered.
assert output_max > 0
assert output_min < 0
# unit abs
output_abs = min(abs(output_min), abs(output_max))
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
ignore_dim = input_abs < range_eps
input_abs[ignore_dim] = output_abs
# don't scale constant channels
scale = output_abs / input_abs
offset = torch.zeros_like(input_mean)
elif mode == "gaussian":
ignore_dim = input_std < range_eps
scale = input_std.clone()
scale[ignore_dim] = 1
scale = 1 / scale
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
# save
this_params = nn.ParameterDict(
{
"scale": scale,
"offset": offset,
"input_stats": nn.ParameterDict(
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
),
}
)
for p in this_params.parameters():
p.requires_grad_(False)
return this_params
def _normalize(x, params, forward=True):
assert "scale" in params
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
scale = params["scale"]
offset = params["offset"]
x = x.to(device=scale.device, dtype=scale.dtype)
src_shape = x.shape
x = x.reshape(-1, scale.shape[0])
x = x * scale + offset if forward else (x - offset) / scale
x = x.reshape(src_shape)
return x
def test():
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0, atol=1e-3)
assert np.allclose(datan.min(), 0.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
data = torch.zeros((100, 10, 9, 2)).uniform_()
normalizer = SingleFieldLinearNormalizer()
normalizer.fit(data, mode="gaussian", last_n_dims=0)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
assert np.allclose(datan.std(), 1.0, atol=1e-3)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
# dict
data = torch.zeros((100, 10, 9, 2)).uniform_()
data[..., 0, 0] = 0
normalizer = LinearNormalizer()
normalizer.fit(data, mode="limits", last_n_dims=2)
datan = normalizer.normalize(data)
assert datan.shape == data.shape
assert np.allclose(datan.max(), 1.0)
assert np.allclose(datan.min(), -1.0)
dataun = normalizer.unnormalize(datan)
assert torch.allclose(data, dataun, atol=1e-7)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
data = {
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
}
normalizer = LinearNormalizer()
normalizer.fit(data)
datan = normalizer.normalize(data)
dataun = normalizer.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)
_ = normalizer.get_input_stats()
_ = normalizer.get_output_stats()
state_dict = normalizer.state_dict()
n = LinearNormalizer()
n.load_state_dict(state_dict)
datan = n.normalize(data)
dataun = n.unnormalize(datan)
for key in data:
assert torch.allclose(data[key], dataun[key], atol=1e-4)

View File

@@ -1,19 +0,0 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

View File

@@ -1,971 +0,0 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert list not in type_func_dict
assert tuple not in type_func_dict
assert dict not in type_func_dict
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: func,
type(None): lambda x: x,
},
)
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
np.ndarray: func,
type(None): lambda x: x,
},
)
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
ndarray_Func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
},
)
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
},
)
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.detach(),
},
)
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
},
)
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
},
)
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
},
)
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
},
)
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
},
)
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
},
)
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
},
)
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
},
)
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
},
)
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
},
)
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
},
)
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
},
)
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert begin_axis <= end_axis
assert begin_axis >= 0
assert end_axis < len(x.shape)
assert isinstance(target_dims, (tuple, list))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t
),
type(None): lambda x: x,
},
)
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]
),
type(None): lambda x: x,
},
)
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
)
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ....]
"""
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq,
{
torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
x, p, b, ps, pv
),
type(None): lambda x: x,
},
)
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x,
{
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
},
)
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
return [(new_key, d)]
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
inputs_as_args (bool) whether to feed input as a args list to the op
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
return outputs

View File

@@ -0,0 +1,723 @@
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Move EMA out of policy.
- Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
- One more pass on comments and documentation.
"""
import copy
import logging
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
import einops
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
populate_queues,
)
class DiffusionPolicy(nn.Module):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = _DiffusionUnetImagePolicy(cfg)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = _EMA(cfg, model=self.ema_diffusion)
# TODO(alexander-soare): Move optimizer out of policy.
self.optimizer = torch.optim.Adam(
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
)
# TODO(alexander-soare): Move LR scheduler out of policy.
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=lr_scheduler_num_training_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
"action": deque(maxlen=self.cfg.n_action_steps),
}
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
underlying diffusion model. Here's how it works:
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
copied `n_obs_steps` times to fill the cache).
- The diffusion model generates `horizon` steps worth of actions.
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
Schematically this looks like:
----------------------------------------------------------------------------------------------
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
----------------------------------------------------------------------------------------------
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
weights.
"""
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if not self.training and self.ema_diffusion is not None:
actions = self.ema_diffusion.generate_actions(batch)
else:
actions = self.diffusion.generate_actions(batch)
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def update(self, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
loss = self.forward(batch)["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"update_s": time.time() - start_time,
}
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0
class _DiffusionUnetImagePolicy(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
self.cfg = cfg
self.rgb_encoder = _RgbEncoder(cfg)
self.unet = _ConditionalUnet1D(
cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
)
self.noise_scheduler = DDPMScheduler(
num_train_timesteps=cfg.num_train_timesteps,
beta_start=cfg.beta_start,
beta_end=cfg.beta_end,
beta_schedule=cfg.beta_schedule,
variance_type="fixed_small",
clip_sample=cfg.clip_sample,
clip_sample_range=cfg.clip_sample_range,
prediction_type=cfg.prediction_type,
)
if cfg.num_inference_steps is None:
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
else:
self.num_inference_steps = cfg.num_inference_steps
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
dtype=dtype,
device=device,
generator=generator,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
# Predict model output.
model_output = self.unet(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
global_cond=global_cond,
)
# Compute previous image: x_t -> x_t-1
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
return sample
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
}
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.cfg.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.cfg.action_dim]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.cfg.n_action_steps
actions = actions[:, start:end]
return actions
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
"""
This function expects `batch` to have (at least):
{
"observation.state": (B, n_obs_steps, state_dim)
"observation.image": (B, n_obs_steps, C, H, W)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.cfg.horizon
assert n_obs_steps == self.cfg.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
trajectory = batch["action"]
# Forward diffusion.
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(trajectory.shape[0],),
device=trajectory.device,
).long()
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
# Compute the loss.
# The target is either the original trajectory, or the noise.
if self.cfg.prediction_type == "epsilon":
target = eps
elif self.cfg.prediction_type == "sample":
target = batch["action"]
else:
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)
return loss.mean()
class _RgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
def __init__(self, cfg: DiffusionConfig):
super().__init__()
# Set up optional preprocessing.
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
self.normalizer = nn.Identity()
else:
self.normalizer = torchvision.transforms.Normalize(
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
)
if cfg.crop_shape is not None:
self.do_crop = True
# Always use center crop for eval
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
if cfg.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
# Set up backbone.
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
pretrained=cfg.use_pretrained_backbone
)
# Note: This assumes that the layer4 feature map is children()[-3]
# TODO(alexander-soare): Use a safer alternative.
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
if cfg.use_group_norm:
if cfg.use_pretrained_backbone:
raise ValueError(
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
)
self.backbone = _replace_submodules(
root_module=self.backbone,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
)
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, C, H, W) image tensor with pixel values in [0, 1].
Returns:
(B, D) image feature.
"""
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
x = self.normalizer(x)
if self.do_crop:
if self.training: # noqa: SIM108
x = self.maybe_random_crop(x)
else:
# Always use center crop for eval.
x = self.center_crop(x)
# Extract backbone feature.
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
# Final linear layer with non-linearity.
x = self.relu(self.out(x))
return x
def _replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
Args:
root_module: The module for which the submodules need to be replaced
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
func: Takes a module as an argument and returns a new module to replace it with.
Returns:
The root module with its submodules replaced.
"""
if predicate(root_module):
return func(root_module)
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parents, k in replace_list:
parent_module = root_module
if len(parents) > 0:
parent_module = root_module.get_submodule(".".join(parents))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
return root_module
class _SinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class _Conv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
class _ConditionalUnet1D(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code.
"""
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
super().__init__()
self.cfg = cfg
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
)
# Unet encoder.
common_res_block_kwargs = {
"cond_dim": cond_dim,
"kernel_size": cfg.kernel_size,
"n_groups": cfg.n_groups,
"use_film_scale_modulation": cfg.use_film_scale_modulation,
}
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
global_cond: (B, global_cond_dim)
output: (B, T, input_dim)
Returns:
(B, T, input_dim) diffusion model prediction.
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
# Run encoder, keeping track of skip features to pass to the decoder.
encoder_skip_features: list[Tensor] = []
for resnet, resnet2, downsample in self.down_modules:
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
# Run decoder, using the skip features from the encoder.
for resnet, resnet2, upsample in self.up_modules:
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b d t -> b t d")
return x
class _ConditionalResidualBlock1D(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
use_film_scale_modulation: bool = False,
):
super().__init__()
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.use_film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out
class _EMA:
"""
Exponential Moving Average of models weights
"""
def __init__(self, cfg: DiffusionConfig, model: nn.Module):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_alpha (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = cfg.ema_update_after_step
self.inv_gamma = cfg.ema_inv_gamma
self.power = cfg.ema_power
self.min_alpha = cfg.ema_min_alpha
self.max_alpha = cfg.ema_max_alpha
self.alpha = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_alpha, min(value, self.max_alpha))
@torch.no_grad()
def step(self, new_model):
self.alpha = self.get_decay(self.optimization_step)
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
# Iterate over immediate parameters only.
for param, ema_param in zip(
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
):
if isinstance(param, dict):
raise RuntimeError("Dict parameter not supported")
if isinstance(module, _BatchNorm) or not param.requires_grad:
# Copy BatchNorm parameters, and non-trainable parameters directly.
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.alpha)
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
self.optimization_step += 1

View File

@@ -1,199 +0,0 @@
import copy
import time
import hydra
import torch
import torch.nn as nn
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
n_obs_steps,
num_inference_steps=None,
obs_as_global_cond=True,
diffusion_step_embed_dim=256,
down_dims=(256, 512, 1024),
kernel_size=5,
n_groups=8,
cond_predict_scale=True,
# parameters passed to step
**kwargs,
):
super().__init__()
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
)
self.diffusion = DiffusionUnetImagePolicy(
shape_meta=shape_meta,
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
horizon=horizon,
n_action_steps=n_action_steps,
n_obs_steps=n_obs_steps,
num_inference_steps=num_inference_steps,
obs_as_global_cond=obs_as_global_cond,
diffusion_step_embed_dim=diffusion_step_embed_dim,
down_dims=down_dims,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
# parameters passed to step
**kwargs,
)
self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
params=self.diffusion.parameters(),
)
# TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.offline_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
@torch.no_grad()
def forward(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
action = out["action"].squeeze(0)
return action
def update(self, replay_buffer, step):
start_time = time.time()
self.diffusion.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
# |o|o| observations: 2
# | |a|a|a|a|a|a|a|a| actions executed: 8
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
image = batch["observation", "image"]
state = batch["observation", "state"]
action = batch["action"]
assert image.shape[1] == horizon
assert state.shape[1] == horizon
assert action.shape[1] == horizon
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
raise NotImplementedError()
# keep first 2 observations of the slice corresponding to t=[-1,0]
image = image[:, : self.cfg.n_obs_steps]
state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
loss = self.diffusion.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"data_s": data_s,
"update_s": time.time() - start_time,
}
# TODO(rcadene): remove hardcoding
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
if step % 168 == 0:
self.global_step += 1
return info
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)

View File

@@ -1,76 +0,0 @@
from typing import Callable, Dict
import torch
import torch.nn as nn
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = "cpu"
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to("cpu")
return resnet_model
def dict_apply(
x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = {}
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
def replace_submodules(
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule(".".join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
assert len(bn_list) == 0
return root_module

View File

@@ -1,40 +1,61 @@
def make_policy(cfg):
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPC
import inspect
policy = TDMPC(cfg.policy, cfg.device)
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from omegaconf import DictConfig, OmegaConf
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
assert set(hydra_cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
policy_cfg = policy_cfg_class(
**{
k: v
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
return policy_cfg
def make_policy(hydra_cfg: DictConfig):
if hydra_cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(
hydra_cfg.policy,
n_obs_steps=hydra_cfg.n_obs_steps,
n_action_steps=hydra_cfg.n_action_steps,
device=hydra_cfg.device,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
elif hydra_cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(cfg.policy.name)
raise ValueError(hydra_cfg.policy.name)
if cfg.policy.pretrained_model_path:
if hydra_cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
if "offline" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
elif "final" in hydra_cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
policy.load(hydra_cfg.policy.pretrained_model_path)
return policy

View File

@@ -0,0 +1,45 @@
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy."""
name: str
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and maybe other information.
"""
def select_action(self, batch: dict[str, Tensor]):
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
def update(self, batch):
"""Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
disappear.
"""

View File

@@ -1,6 +1,7 @@
# ruff: noqa: N806
import time
from collections import deque
from copy import deepcopy
import einops
@@ -9,6 +10,8 @@ import torch
import torch.nn as nn
import lerobot.common.policies.tdmpc.helper as h
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils.utils import get_safe_torch_device
FIRST_FRAME = 0
@@ -85,24 +88,28 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(nn.Module):
class TDMPCPolicy(nn.Module):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
name = "tdmpc"
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device(device)
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
self.device = get_safe_torch_device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
self.model = TOLD(cfg)
self.model.to(self.device)
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
self.register_buffer("step", torch.zeros(1))
@@ -123,21 +130,54 @@ class TDMPC(nn.Module):
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs = {
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
action = self.act(obs, t0=t0, step=self.step.item())
@torch.no_grad()
def select_action(self, batch, step):
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
self._queues = populate_queues(self._queues, batch)
t0 = step == 0
self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
if self.n_obs_steps == 1:
# hack to remove the time dimension
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
actions = []
batch_size = batch["observation.image"].shape[0]
for i in range(batch_size):
obs = {
"rgb": batch["observation.image"][[i]],
"state": batch["observation.state"][[i]],
}
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step)
actions.append(action)
action = torch.stack(actions)
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
if i in range(self.n_action_steps):
self._queues["action"].append(action)
action = self._queues["action"].popleft()
return action
@torch.no_grad()
@@ -286,117 +326,58 @@ class TDMPC(nn.Module):
def _td_target(self, next_z, reward, mask):
"""Compute the TD-target from a reward and the observation at the following time step."""
next_v = self.model.V(next_z)
td_target = reward + self.cfg.discount * mask * next_v
td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
return td_target
def update(self, replay_buffer, step, demo_buffer=None):
def forward(self, batch, step):
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
raise NotImplementedError()
def update(self, batch, step):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
batch_size = batch["index"].shape[0]
if demo_buffer is None:
demo_batch_size = 0
else:
# Update oversampling ratio
demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
demo_num_slices = int(demo_pc_batch * self.batch_size)
demo_batch_size = self.cfg.horizon * demo_num_slices
batch_size -= demo_batch_size
num_slices -= demo_num_slices
replay_buffer._sampler.num_slices = num_slices
demo_buffer._sampler.num_slices = demo_num_slices
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
# batch size b = 256, time/horizon t = 5
# b t ... -> t b ...
for key in batch:
if batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
assert demo_batch_size % self.cfg.horizon == 0
assert demo_batch_size % demo_num_slices == 0
action = batch["action"]
reward = batch["next.reward"]
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
obses = {
"rgb": batch["observation.image"],
"state": batch["observation.state"],
}
# Sample from interaction dataset
def process_batch(batch, horizon, num_slices):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
"state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
}
action = batch["action"].to(self.device, non_blocking=True)
next_obses = {
"rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
"state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
}
reward = batch["next", "reward"].to(self.device, non_blocking=True)
idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# TODO(rcadene): rearrange directly in offline dataset
if reward.ndim == 2:
reward = einops.rearrange(reward, "h t -> h t 1")
assert reward.ndim == 3
assert reward.shape == (horizon, num_slices, 1)
# We dont use `batch["next", "done"]` since it only indicates the end of an
# episode, but not the end of the trajectory of an episode.
# Neither does `batch["next", "terminated"]`
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
return obs, action, next_obses, reward, mask, done, idxs, weights
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
)
# Sample from demonstration dataset
if demo_batch_size > 0:
demo_batch = demo_buffer.sample(demo_batch_size)
(
demo_obs,
demo_action,
demo_next_obses,
demo_reward,
demo_mask,
demo_done,
demo_idxs,
demo_weights,
) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
if isinstance(obs, dict):
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
else:
obs = torch.cat([obs, demo_obs])
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
action = torch.cat([action, demo_action], dim=1)
reward = torch.cat([reward, demo_reward], dim=1)
mask = torch.cat([mask, demo_mask], dim=1)
done = torch.cat([done, demo_done], dim=1)
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
shapes = {}
for k in obses:
shapes[k] = obses[k].shape
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
# Apply augmentations
aug_tf = h.aug(self.cfg)
obs = aug_tf(obs)
obses = aug_tf(obses)
for k in next_obses:
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
next_obses = aug_tf(next_obses)
for k in next_obses:
next_obses[k] = einops.rearrange(
next_obses[k],
"(h t) ... -> h t ...",
h=self.cfg.horizon,
t=self.cfg.batch_size,
)
for k in obses:
t, b = shapes[k][:2]
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
horizon = self.cfg.horizon
obs, next_obses = {}, {}
for k in obses:
obs[k] = obses[k][0]
next_obses[k] = obses[k][1:].clone()
horizon = next_obses["rgb"].shape[0]
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
@@ -414,7 +395,7 @@ class TDMPC(nn.Module):
td_targets = self._td_target(next_z, reward, mask)
# Latent rollout
zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
reward_preds = torch.empty_like(reward, device=self.device)
assert reward.shape[0] == horizon
z = self.model.encode(obs)
@@ -423,22 +404,21 @@ class TDMPC(nn.Module):
for t in range(horizon):
z, reward_pred = self.model.next(z, action[t])
zs[t + 1] = z
reward_preds[t] = reward_pred
reward_preds[t] = reward_pred.squeeze(1)
with torch.no_grad():
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
# Predictions
qs = self.model.Q(zs[:-1], action, return_type="all")
qs = qs.squeeze(3)
value_info["Q"] = qs.mean().item()
v = self.model.V(zs[:-1])
value_info["V"] = v.mean().item()
# Losses
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
dim=0
)
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
q_value_loss, priority_loss = 0, 0
for q in range(self.cfg.num_q):
@@ -446,7 +426,9 @@ class TDMPC(nn.Module):
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
expectile = h.linear_schedule(self.cfg.expectile, step)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
dim=0
)
total_loss = (
self.cfg.consistency_coef * consistency_loss
@@ -455,7 +437,7 @@ class TDMPC(nn.Module):
+ self.cfg.value_coef * v_value_loss
)
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss = (total_loss * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
@@ -468,19 +450,20 @@ class TDMPC(nn.Module):
)
self.optim.step()
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
has_nan = torch.isnan(priorities).any().item()
if has_nan:
print(f"priorities has nan: {priorities=}")
else:
replay_buffer.update_priority(
idxs[:num_slices],
priorities[:num_slices],
)
if demo_batch_size > 0:
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()
# has_nan = torch.isnan(priorities).any().item()
# if has_nan:
# print(f"priorities has nan: {priorities=}")
# else:
# replay_buffer.update_priority(
# idxs[:num_slices],
# priorities[:num_slices],
# )
# if demo_batch_size > 0:
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
@@ -503,7 +486,7 @@ class TDMPC(nn.Module):
"data_s": data_s,
"update_s": time.time() - start_time,
}
info["demo_batch_size"] = demo_batch_size
# info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile
info.update(value_info)
info.update(pi_update_info)

View File

@@ -0,0 +1,30 @@
import torch
from torch import nn
def populate_queues(queues, batch):
for key in batch:
if len(queues[key]) != queues[key].maxlen:
# initialize by copying the first observation several times until the queue is full
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
# add latest observation to the queue
queues[key].append(batch[key])
return queues
def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Note: assumes that all parameters have the same device
"""
return next(iter(module.parameters())).device
def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
"""Get a module's parameter dtype by checking one of its parameters.
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype

View File

@@ -0,0 +1,65 @@
from torchvision.transforms.v2 import Compose, Transform
def apply_inverse_transform(item, transform):
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
for tf in transforms[::-1]:
if tf.invertible:
item = tf.inverse_transform(item)
else:
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
return item
class NormalizeTransform(Transform):
invertible = True
def __init__(
self,
stats: dict,
in_keys: list[str] = None,
out_keys: list[str] | None = None,
in_keys_inv: list[str] | None = None,
out_keys_inv: list[str] | None = None,
mode="mean_std",
):
super().__init__()
self.in_keys = in_keys
self.out_keys = in_keys if out_keys is None else out_keys
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
item[outkey] = item[outkey] * 2 - 1
return item
def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item

View File

@@ -1,45 +0,0 @@
import logging
import random
from datetime import datetime
import numpy as np
import torch
def set_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
formatter = logging.Formatter()
formatter.format = custom_format
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
for suffix in suffixes:
if abs(num) < divisor:
return f"{num:.0f}{suffix}"
num /= divisor
return num

View File

@@ -0,0 +1,44 @@
import importlib
import logging
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
Check if the package spec exists and grab its version to avoid importing a local directory.
**Note:** this doesn't work for all packages.
"""
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
except importlib.metadata.PackageNotFoundError:
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logging.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else:
return package_exists
_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")

View File

@@ -0,0 +1,111 @@
import logging
import os.path as osp
import random
from datetime import datetime
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
case "cpu":
device = torch.device("cpu")
if log:
logging.warning("Using CPU, this will be slow.")
case _:
device = torch.device(cfg_device)
if log:
logging.warning(f"Using custom {cfg_device} device.")
return device
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
formatter = logging.Formatter()
formatter.format = custom_format
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
for suffix in suffixes:
if abs(num) < divisor:
return f"{num:.0f}{suffix}"
num /= divisor
return num
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# TODO(alexander-soare): Resolve configs without Hydra initialization.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
def print_cuda_memory_usage():
"""Use this function to locate and debug memory leak."""
import gc
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))

View File

@@ -5,11 +5,14 @@ defaults:
hydra:
run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job:
name: default
seed: 1337
# batch size for TorchRL SerialEnv. Each underlying env will get the seed = seed + env_index
# NOTE: only diffusion policy supports rollout_batch_size > 1
rollout_batch_size: 1
device: cuda # cpu
prefetch: 4
eval_freq: ???
@@ -23,12 +26,17 @@ fps: ???
offline_prioritized_sampler: true
dataset_id: ???
n_action_steps: ???
n_obs_steps: ???
env: ???
policy: ???
wandb:
enable: true
# Set to true to disable saving an artifact despite save_model == True
disable_artifact: false
project: lerobot
notes: ""

View File

@@ -4,19 +4,20 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as simxarm, need to adjust
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 50
dataset_id: aloha_sim_insertion_human
env:
name: aloha
task: sim_insertion_human
task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
action_repeat: 1
episode_length: 400
fps: ${fps}

View File

@@ -4,19 +4,20 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
# TODO: same as simxarm, need to adjust
# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
fps: 10
dataset_id: pusht
env:
name: pusht
task: pusht
task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96
action_repeat: 1
episode_length: 300
fps: ${fps}

Some files were not shown because too many files have changed in this diff Show More