Compare commits

...

84 Commits

Author SHA1 Message Date
Thomas Wolf
a1e47202c0 update 2024-04-03 07:46:17 +02:00
Thomas Wolf
24821fee24 update 2024-04-02 22:49:16 +02:00
Thomas Wolf
4751642ace adding docstring and from_pretrained/save_pretrained 2024-04-02 22:45:21 +02:00
Alexander Soare
11cbf1bea1 Merge pull request #53 from alexander-soare/finish_examples
Add examples 2 and 3
2024-04-01 11:52:41 +01:00
Alexander Soare
f1148b8c2d Merge remote-tracking branch 'upstream/main' into finish_examples 2024-04-01 11:31:31 +01:00
Simon Alibert
2a98cc71ed Merge pull request #56 from huggingface/user/aliberts/2024_03_27_improve_ci
Add code coverage, more end-to-end tests
2024-03-28 10:57:44 +01:00
Simon Alibert
a7c9b78e56 Deactivate eval ACT on Aloha (policy is None) 2024-03-28 10:55:11 +01:00
Simon Alibert
404b8f8a75 Fix end-to-end ACT train on Aloha 2024-03-28 10:35:11 +01:00
Simon Alibert
17c2bbbeb8 remove todo 2024-03-28 10:35:11 +01:00
Simon Alibert
006e5feabf WIP add code coverage 2024-03-28 10:35:11 +01:00
Simon Alibert
b99ee8180a Add more end-to-end tests 2024-03-28 10:35:11 +01:00
Simon Alibert
6bddcb647e Add test_aloha env test 2024-03-28 10:35:11 +01:00
Simon Alibert
58df2066a9 Add pytest-cov 2024-03-28 10:35:11 +01:00
Simon Alibert
c89aa4f8ed Merge pull request #57 from huggingface/user/aliberts/2024_03_27_improve_readme
Improve readme
2024-03-28 10:26:48 +01:00
Simon Alibert
62aad7104b Pull merge 2024-03-28 10:03:25 +01:00
Simon Alibert
9d9148dad8 Fixes for #57 2024-03-28 10:01:33 +01:00
Simon Alibert
1b6cb2b1be Add space
Co-authored-by: Remi <re.cadene@gmail.com>
2024-03-27 20:51:52 +01:00
Simon Alibert
6f1a0aefab typo fix
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-03-27 20:50:23 +01:00
Alexander Soare
b7c9c33072 revision 2024-03-27 18:33:48 +00:00
Alexander Soare
120f0aef5c Merge remote-tracking branch 'upstream/main' into finish_examples 2024-03-27 17:52:36 +00:00
Simon Alibert
032200e32c Typo fix 2024-03-27 17:05:04 +01:00
Simon Alibert
de1e9187c8 Formatting 2024-03-27 16:56:21 +01:00
Simon Alibert
4f8f1926f9 Update pip install without requirements.txt 2024-03-27 16:49:27 +01:00
Simon Alibert
6710121a29 Revert "Add requirements.txt"
This reverts commit 18e7f4c3e6.
2024-03-27 16:47:49 +01:00
Simon Alibert
5f4b8ab899 Add more exhaustive install instructions 2024-03-27 16:35:32 +01:00
Simon Alibert
18e7f4c3e6 Add requirements.txt 2024-03-27 16:33:54 +01:00
Simon Alibert
643d64e2a8 Add cmake 2024-03-27 16:33:26 +01:00
Alexander Soare
c037722e23 Merge pull request #58 from alexander-soare/update_diffusion_model
Update diffusion model
2024-03-27 13:34:01 +00:00
Alexander Soare
6cd671040f fix revision 2024-03-27 13:22:14 +00:00
Alexander Soare
b6353964ba fix bug: use provided revision instead of hardcoded one 2024-03-27 13:08:47 +00:00
Alexander Soare
64c8851c40 Merge branch 'tidy_diffusion_config' into update_diffusion_model 2024-03-27 13:06:08 +00:00
Alexander Soare
dc745e3037 Remove unused part of diffusion policy config 2024-03-27 13:05:13 +00:00
Simon Alibert
6f0c2445ca Improve readme format 2024-03-27 13:26:54 +01:00
Simon Alibert
d1d2229407 WIP add badges 2024-03-27 13:26:45 +01:00
Alexander Soare
68d02c80cf Remove b/c workaround 2024-03-27 12:03:19 +00:00
Alexander Soare
011f2d27fe fix tests 2024-03-26 16:40:54 +00:00
Alexander Soare
be4441c7ff update README 2024-03-26 16:28:16 +00:00
Alexander Soare
1ed0110900 finish examples 2 and 3 2024-03-26 16:13:40 +00:00
Remi
cb6d1e0871 Merge pull request #49 from huggingface/user/rcadene/2024_03_25_readme
Improve README
2024-03-26 11:50:24 +01:00
Cadene
9ced0cf1fb unskip 2024-03-26 10:45:31 +00:00
Cadene
98534d1a63 skip 2024-03-26 10:42:53 +00:00
Cadene
edacc1d2a0 Add root in example 2024-03-26 10:40:06 +00:00
Cadene
5a46b8a2a9 fix tests 2024-03-26 10:24:46 +00:00
Cadene
4a8c5e238e issue with cat_and_write_video 2024-03-26 10:12:16 +00:00
Alexander Soare
1a1308d62f fix environment seeding
add fixes for reproducibility

only try to start env if it is closed

revision

fix normalization and data type

Improve README

Improve README

Tests are passing, Eval pretrained model works, Add gif

Update gif

Update gif

Update gif

Update gif

Update README

Update README

update minor

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Address suggestions

Update thumbnail + stats

Update thumbnail + stats

Update README.md

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

Add more comments

Add test_examples.py
2024-03-26 10:10:43 +00:00
Simon Alibert
203bcd7ca5 Merge pull request #47 from huggingface/user/aliberts/2024_03_22_fix_simxarm
Port simxarm, upgrade gym to gymnasium
2024-03-26 10:17:52 +01:00
Simon Alibert
98b9631aa6 Add n_obs_steps in default.yaml config 2024-03-26 10:08:00 +01:00
Simon Alibert
c5635b7d94 Minor fixes for #47 2024-03-25 18:50:47 +01:00
Simon Alibert
f00252552a fix ci 2024-03-25 17:38:08 +01:00
Simon Alibert
90f6af9736 Update ci 2024-03-25 17:33:47 +01:00
Simon Alibert
bcfdba109f Update pre-commit & run on all files 2024-03-25 17:29:35 +01:00
Simon Alibert
0fae5b206b Allow mujoco minor revision updates 2024-03-25 17:25:55 +01:00
Simon Alibert
7cdd6d2450 Renamed set_seed -> set_global_seed 2024-03-25 17:19:28 +01:00
Simon Alibert
058ac991eb Add simxarm back into tests 2024-03-25 16:35:46 +01:00
Simon Alibert
d3adaf1379 Add stat.pth for xarm_lift_medium 2024-03-25 15:55:45 +01:00
Simon Alibert
dc89166bee Upgrade gym to gymnasium 2024-03-25 15:12:21 +01:00
Simon Alibert
5ef813ff1e Remove deprecated code 2024-03-25 13:22:49 +01:00
Simon Alibert
a2ac83276b Update ci env 2024-03-25 13:02:50 +01:00
Simon Alibert
c0833f1c2d Remove simxarm download and preproc hack 2024-03-25 12:41:17 +01:00
Simon Alibert
de5c30405e fix wrong version 2024-03-25 12:35:06 +01:00
Simon Alibert
462e7469e8 Add xarm_lift_medium revision 1.0 to hub 2024-03-25 12:28:07 +01:00
Simon Alibert
298d391b26 Add simxarm license 2024-03-25 12:28:07 +01:00
Cadene
be6364f109 fix, it's training now! 2024-03-25 12:28:07 +01:00
Simon Alibert
127de1258d WIP 2024-03-25 12:28:07 +01:00
Cadene
b905111895 fix render issue 2024-03-25 12:28:07 +01:00
Simon Alibert
0c41675986 fix __init__ import Base 2024-03-25 12:28:07 +01:00
Simon Alibert
1c24bbda3f WIP Upgrading simxam from mujoco-py to mujoco python bindings 2024-03-25 12:28:07 +01:00
Alexander Soare
e41c420a96 Merge pull request #48 from alexander-soare/fix_visualization
Fix normalization of last frame and data type in visualization
2024-03-25 09:53:06 +00:00
Alexander Soare
4a48b77540 fix normalization and data type 2024-03-25 09:44:03 +00:00
Remi
f3cfc8b3b4 Merge pull request #46 from huggingface/user/rcadene/2024_03_23_update_stats_v1.2
Fix bug with stats.pth + Move from cadene to lerobot + Update datasets to v1.2
2024-03-24 17:53:32 +01:00
Cadene
d2ef43436c move from cadene to lerobot 2024-03-23 13:34:35 +00:00
Cadene
40f3783fca v1.2 2024-03-23 11:41:56 +00:00
Alexander Soare
e21ed6f510 Merge pull request #45 from alexander-soare/fix_environment_seeding
Reproduce original diffusion policy pusht image eval
2024-03-22 16:27:48 +00:00
Alexander Soare
bd40ffc53c revision 2024-03-22 15:43:45 +00:00
Alexander Soare
d43fa600a0 only try to start env if it is closed 2024-03-22 15:32:55 +00:00
Alexander Soare
e698d38a35 Merge remote-tracking branch 'upstream/main' into fix_environment_seeding 2024-03-22 15:11:15 +00:00
Alexander Soare
15ff3b3af8 add fixes for reproducibility 2024-03-22 15:06:57 +00:00
Remi
a80d9c0257 Merge pull request #44 from alexander-soare/run_model_from_hub
Run model from HuggingFace hub
2024-03-22 16:03:27 +01:00
Alexander Soare
b9047fbdd2 fix environment seeding 2024-03-22 13:25:23 +00:00
Alexander Soare
115927d0f6 make sure to pass stats.pth arg 2024-03-22 12:58:59 +00:00
Alexander Soare
529f42643d revision 2024-03-22 12:33:25 +00:00
Alexander Soare
1b279a1fc0 fix test 2024-03-22 10:58:27 +00:00
Alexander Soare
3f0f95f4c0 update readme 2024-03-22 10:34:22 +00:00
Alexander Soare
8720c568d0 Add ability to eval hub model 2024-03-22 10:26:55 +00:00
121 changed files with 2873 additions and 733 deletions

1
.gitattributes vendored
View File

@@ -1 +1,2 @@
*.memmap filter=lfs diff=lfs merge=lfs -text
*.stl filter=lfs diff=lfs merge=lfs -text

401
.github/poetry/cpu/poetry.lock generated vendored
View File

@@ -327,6 +327,35 @@ files = [
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
]
[[package]]
name = "cmake"
version = "3.29.0.1"
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
optional = false
python-versions = ">=3.7"
files = [
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
]
[package.extras]
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -339,72 +368,72 @@ files = [
]
[[package]]
name = "cython"
version = "3.0.9"
description = "The Cython compiler for writing C extensions in the Python language."
name = "coverage"
version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"},
{file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"},
{file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"},
{file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"},
{file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"},
{file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"},
{file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"},
{file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"},
{file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"},
{file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"},
{file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"},
{file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"},
{file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"},
{file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"},
{file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"},
{file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"},
{file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"},
{file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"},
{file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"},
{file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"},
{file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"},
{file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"},
{file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"},
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]]
name = "debugpy"
version = "1.8.1"
@@ -500,13 +529,13 @@ files = [
[[package]]
name = "dm-control"
version = "1.0.16"
version = "1.0.14"
description = "Continuous control environments and MuJoCo Python bindings."
optional = false
python-versions = ">=3.8"
files = [
{file = "dm_control-1.0.16-py3-none-any.whl", hash = "sha256:341f582e26d88556ac0dd1963e11c25f93ef6649b07c75d04570a7de3edebd1b"},
{file = "dm_control-1.0.16.tar.gz", hash = "sha256:bbdd6dc54b4dbc33eef6c43294edb0a927c3a5f35621e698480f88f55f48a1fd"},
{file = "dm_control-1.0.14-py3-none-any.whl", hash = "sha256:883c63244a7ebf598700a97564ed19fffd3479ca79efd090aed881609cdb9fc6"},
{file = "dm_control-1.0.14.tar.gz", hash = "sha256:def1ece747b6f175c581150826b50f1a6134086dab34f8f3fd2d088ea035cf3d"},
]
[package.dependencies]
@@ -516,7 +545,7 @@ dm-tree = "!=0.1.2"
glfw = "*"
labmaze = "*"
lxml = "*"
mujoco = ">=3.1.1"
mujoco = ">=2.3.7"
numpy = ">=1.9.0"
protobuf = ">=3.19.4"
pyopengl = ">=3.1.4"
@@ -635,43 +664,6 @@ files = [
{file = "einops-0.7.0.tar.gz", hash = "sha256:b2b04ad6081a3b227080c9bf5e3ace7160357ff03043cd66cc5b2319eb7031d1"},
]
[[package]]
name = "etils"
version = "1.7.0"
description = "Collection of common python utils"
optional = false
python-versions = ">=3.10"
files = [
{file = "etils-1.7.0-py3-none-any.whl", hash = "sha256:61af8f7c242171de15e22e5da02d527cb9e677d11f8bcafe18fcc3548eee3e60"},
{file = "etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350"},
]
[package.dependencies]
fsspec = {version = "*", optional = true, markers = "extra == \"epath\""}
importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""}
typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""}
zipp = {version = "*", optional = true, markers = "extra == \"epath\""}
[package.extras]
all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"]
array-types = ["etils[enp]"]
dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"]
docs = ["etils[all,dev]", "sphinx-apitree[ext]"]
eapp = ["absl-py", "etils[epy]", "simple_parsing"]
ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"]
edc = ["etils[epy]"]
enp = ["etils[epy]", "numpy"]
epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"]
epath-gcs = ["etils[epath]", "gcsfs"]
epath-s3 = ["etils[epath]", "s3fs"]
epy = ["typing_extensions"]
etqdm = ["absl-py", "etils[epy]", "tqdm"]
etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"]
etree-dm = ["dm-tree", "etils[etree]"]
etree-jax = ["etils[etree]", "jax[cpu]"]
etree-tf = ["etils[etree]", "tensorflow"]
lazy-imports = ["etils[ecolab]"]
[[package]]
name = "exceptiongroup"
version = "1.2.0"
@@ -686,6 +678,17 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "farama-notifications"
version = "0.0.4"
description = "Notifications for all Farama Foundation maintained libraries."
optional = false
python-versions = "*"
files = [
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
]
[[package]]
name = "fasteners"
version = "0.19"
@@ -887,43 +890,58 @@ files = [
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]]
name = "gym"
version = "0.26.2"
description = "Gym: A universal API for reinforcement learning environments"
name = "gymnasium"
version = "0.29.1"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
optional = false
python-versions = ">=3.6"
python-versions = ">=3.8"
files = [
{file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"},
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
]
[package.dependencies]
cloudpickle = ">=1.2.0"
gym_notices = ">=0.0.4"
numpy = ">=1.18.0"
farama-notifications = ">=0.0.1"
numpy = ">=1.21.0"
typing-extensions = ">=4.3.0"
[package.extras]
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
all = ["ale-py (>=0.8.0,<0.9.0)", "box2d-py (==2.3.5)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco_py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (==4.*)"]
atari = ["ale-py (>=0.8.0,<0.9.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.0)", "swig (==4.*)"]
classic-control = ["pygame (==2.1.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (==2.2)"]
mujoco-py = ["mujoco_py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)"]
testing = ["box2d-py (==2.3.5)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco_py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (==4.*)"]
toy-text = ["pygame (==2.1.0)"]
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
[[package]]
name = "gym-notices"
version = "0.0.8"
description = "Notices for gym"
name = "gymnasium-robotics"
version = "1.2.4"
description = "Robotics environments for the Gymnasium repo."
optional = false
python-versions = "*"
python-versions = ">=3.8"
files = [
{file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"},
{file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"},
{file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"},
{file = "gymnasium_robotics-1.2.4-py3-none-any.whl", hash = "sha256:c2cb23e087ca0280ae6802837eb7b3a6d14e5bd24c00803ab09f015fcff3eef5"},
]
[package.dependencies]
gymnasium = ">=0.26"
imageio = "*"
Jinja2 = ">=3.0.3"
mujoco = ">=2.3.3,<3.0"
numpy = ">=1.21.0"
PettingZoo = ">=1.23.0"
[package.extras]
mujoco-py = ["cython (<3)", "mujoco-py (>=2.1,<2.2)"]
testing = ["Jinja2 (>=3.0.3)", "PettingZoo (>=1.23.0)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "pytest (==7.0.1)"]
[[package]]
name = "h5py"
version = "3.10.0"
@@ -1105,21 +1123,6 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
perf = ["ipython"]
testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
[[package]]
name = "importlib-resources"
version = "6.3.2"
description = "Read resources from Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "importlib_resources-6.3.2-py3-none-any.whl", hash = "sha256:f41f4098b16cd140a97d256137cfd943d958219007990b2afb00439fc623f580"},
{file = "importlib_resources-6.3.2.tar.gz", hash = "sha256:963eb79649252b0160c1afcfe5a1d3fe3ad66edd0a8b114beacffb70c0674223"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["jaraco.collections", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
@@ -1457,65 +1460,44 @@ tests = ["pytest (>=4.6)"]
[[package]]
name = "mujoco"
version = "3.1.3"
version = "2.3.7"
description = "MuJoCo Physics Simulator"
optional = false
python-versions = ">=3.8"
files = [
{file = "mujoco-3.1.3-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:1a07e33443ca88c77128336e550502c58721e37b3830af29f0118311c17d826e"},
{file = "mujoco-3.1.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:da442c45fa08cf7f307a6f2484ff382b90714b9f52aaceffd5fcb8536dbdc11c"},
{file = "mujoco-3.1.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec08dbfddef6e4c6d7b03685b929ed134e8eb9d0dbc788752ff54216b7b3544e"},
{file = "mujoco-3.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1578279ff581ed1c70893cc16ecf48a048a14568e9e64b446a2d32c22b1154c"},
{file = "mujoco-3.1.3-cp310-cp310-win_amd64.whl", hash = "sha256:9a359e7787e1d0bbdb9fafeb31df61261a4cdc42d0a5d77c91fbe57c63e4c6fd"},
{file = "mujoco-3.1.3-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:b070805d65ee6b708ddf1a16a16fc2073ce2d1eea8ea26352b8aee4071de274c"},
{file = "mujoco-3.1.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d789a95150cf1bef21e3a3431c26263730b0437ec3b4794b2eed0f900185746e"},
{file = "mujoco-3.1.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6661aa27c81be338ce0973ba6e83f655ff3cc023ea9d62398f130b46478f708a"},
{file = "mujoco-3.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f79dc6134c90a7274d2663c07bea6d45629ea52ce40bf6722c5d506df909b4b9"},
{file = "mujoco-3.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:b1df674d9486e1bd2e93fb69009d8db4adcf4b3b7edc92da5c98d1c6a2ea7a28"},
{file = "mujoco-3.1.3-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:51841750310a1c4b5e7c7f19d28fe5e3deea0e2c7cc60ebab33c2f07360b1700"},
{file = "mujoco-3.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe4318af5b14ea39bc5b8892c69797a1a9deb02199178814be16abb5611308fb"},
{file = "mujoco-3.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:079f293a56c2b3aa6b4101c3822ee5587b5cc9bf35028afdd1f2128db102ad20"},
{file = "mujoco-3.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4c120414bf89a11538e3f5eb1de6bcd6c4aeade9775ecad3e4eea27d88e1492"},
{file = "mujoco-3.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:7fc9d69383dc0f7c4b775b2be829a065fb78dca743a25f9d864d52174c916b2b"},
{file = "mujoco-3.1.3-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:5a29004079a40d23836228647bae9ea41f77fd7e407e8ad642dc72054e5a099e"},
{file = "mujoco-3.1.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c1d1e083d8825faf9e2609d4e749cc5629ed7735374ed68eb3dde63dd0e4fe73"},
{file = "mujoco-3.1.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc267cf0de8de3c8b317f7c12b2d7a484a7f462263f8ce4c8ae18e9d6817897"},
{file = "mujoco-3.1.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d67a2aae8d5d58f7ac41d0bcffcd745955415887843bc34da8e3d794b46afbae"},
{file = "mujoco-3.1.3-cp38-cp38-win_amd64.whl", hash = "sha256:8ddb9a07d5ad59c67f2d7e79568cba27ad68cf2284a68370f2054dce2e6e4128"},
{file = "mujoco-3.1.3-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:acdc761e8fa7d4bfb9f262b8886dbb3dd41a957c3ef7ec126aae3342f68b1293"},
{file = "mujoco-3.1.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:49bad0da02ebf67ab37a6f6fe435dfc6339f0b46b51b452ee79aaffa5b73659b"},
{file = "mujoco-3.1.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daa6a3ca50a3769ebfd59274651d2edc76b177cd950560022120fb77cd51f607"},
{file = "mujoco-3.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80f3e99d1f2bb02bf7903ba4fef9c31e05fba439b7292ed751390ec78e1eb890"},
{file = "mujoco-3.1.3-cp39-cp39-win_amd64.whl", hash = "sha256:2d2fe38b1a7f64e708e8b9a96cf7677027b33fb6e059184163976c6c03fef4cc"},
{file = "mujoco-3.1.3.tar.gz", hash = "sha256:f700d074031060b46111ddb60432d00425f821eeeaf0ccc76ed95d47861bd4de"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"},
{file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"},
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"},
{file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"},
{file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"},
{file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"},
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"},
{file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"},
{file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"},
{file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"},
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"},
{file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"},
{file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"},
{file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"},
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"},
{file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"},
{file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"},
{file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"},
]
[package.dependencies]
absl-py = "*"
etils = {version = "*", extras = ["epath"]}
glfw = "*"
numpy = "*"
pyopengl = "*"
[[package]]
name = "mujoco-py"
version = "2.1.2.14"
description = ""
optional = false
python-versions = ">=3.6"
files = [
{file = "mujoco-py-2.1.2.14.tar.gz", hash = "sha256:eb5b14485acf80a3cf8c15f4b080c6a28a9f79e68869aa696d16cbd51ea7706f"},
{file = "mujoco_py-2.1.2.14-py3-none-any.whl", hash = "sha256:37c0b41bc0153a8a0eb3663103a67c60f65467753f74e4ff6e68b879f3e3a71f"},
]
[package.dependencies]
cffi = ">=1.10"
Cython = ">=0.27.2"
fasteners = ">=0.15,<1.0"
glfw = ">=1.4.0"
imageio = ">=2.1.2"
numpy = ">=1.11"
[[package]]
name = "networkx"
version = "3.2.1"
@@ -1790,6 +1772,31 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]]
name = "pettingzoo"
version = "1.24.3"
description = "Gymnasium for multi-agent reinforcement learning."
optional = false
python-versions = ">=3.8"
files = [
{file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"},
{file = "pettingzoo-1.24.3.tar.gz", hash = "sha256:91f9094f18e06fb74b98f4099cd22e8ae4396125e51719d50b30c9f1c7ab07e6"},
]
[package.dependencies]
gymnasium = ">=0.28.0"
numpy = ">=1.21.0"
[package.extras]
all = ["box2d-py (==2.3.5)", "chess (==1.9.4)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "rlcard (==1.0.5)", "scipy (>=1.4.1)", "shimmy[openspiel] (>=1.2.0)"]
atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.3.0)"]
butterfly = ["pygame (==2.3.0)", "pymunk (==6.2.0)"]
classic = ["chess (==1.9.4)", "pygame (==2.3.0)", "rlcard (==1.0.5)", "shimmy[openspiel] (>=1.2.0)"]
mpe = ["pygame (==2.3.0)"]
other = ["pillow (>=8.0.1)"]
sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"]
testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"]
[[package]]
name = "pillow"
version = "10.2.0"
@@ -2192,6 +2199,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -3305,4 +3330,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "cbd9aedcb3a24417b85124fb94db706dd6ca0a90dfb610b0aebdcd3aa2a0333c"
content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"

View File

@@ -21,7 +21,6 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = "^3.10"
cython = "^3.0.8"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
dm-env = "^1.6"
@@ -42,23 +41,25 @@ mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^3.1.2"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
dm = "^1.3"
dm-control = "^1.0.16"
dm-control = "1.0.14"
robomimic = "0.2.0"
huggingface-hub = "^0.21.4"
gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
pytest = "^8.1.0"
pytest-cov = "^5.0.0"
[[tool.poetry.source]]

View File

@@ -1,4 +1,4 @@
name: Test
name: Tests
on:
pull_request:
@@ -10,20 +10,15 @@ on:
- main
jobs:
test:
tests:
if: |
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
${{ github.event_name == 'push' }}
runs-on: ubuntu-latest
env:
POETRY_VERSION: 1.8.1
POETRY_VERSION: 1.8.2
DATA_DIR: tests/data
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
PYOPENGL_PLATFORM: egl
MUJOCO_GL: egl
LEROBOT_TESTS_DEVICE: cpu
steps:
#----------------------------------------------
# check-out repo and set-up python
@@ -86,6 +81,10 @@ jobs:
- name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
env:
TMPDIR: ~/tmp
TEMP: ~/tmp
TMP: ~/tmp
run: |
mkdir ~/tmp
poetry install --no-interaction --no-root
@@ -110,35 +109,126 @@ jobs:
run: poetry install --no-interaction
#----------------------------------------------
# run tests
# run tests & coverage
#----------------------------------------------
- name: Run tests
env:
LEROBOT_TESTS_DEVICE: cpu
run: |
source .venv/bin/activate
pytest tests
pytest --cov=./lerobot --cov-report=xml tests
- name: Test train pusht end-to-end
# TODO(aliberts): Link with HF Codecov account
# - name: Upload coverage reports to Codecov with GitHub Action
# uses: codecov/codecov-action@v4
# with:
# files: ./coverage.xml
# verbose: true
#----------------------------------------------
# run end-to-end tests
#----------------------------------------------
- name: Test train ACT on ALOHA end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
hydra.job.name=pusht \
policy=act \
env=aloha \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
device=cpu \
save_model=true \
save_freq=2 \
horizon=20 \
policy.batch_size=2 \
hydra.run.dir=tests/outputs/act/
- name: Test eval ACT on ALOHA end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/act/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/act/models/2.pt
# TODO(aliberts): This takes ~2mn to run, needs to be improved
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py \
# --config lerobot/configs/default.yaml \
# policy=act \
# env=aloha \
# eval_episodes=1 \
# device=cpu
- name: Test train Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=diffusion \
env=pusht \
wandb.enable=False \
offline_steps=2 \
online_steps=0 \
device=cpu \
save_model=true \
save_freq=1 \
hydra.run.dir=tests/outputs/
save_freq=2 \
hydra.run.dir=tests/outputs/diffusion/
- name: Test eval pusht end-to-end
- name: Test eval Diffusion on PushT end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
hydra.job.name=pusht \
env=pusht \
wandb.enable=False \
--config tests/outputs/diffusion/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/models/1.pt
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
- name: Test eval Diffusion on PushT end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=diffusion \
env=pusht \
eval_episodes=1 \
device=cpu
- name: Test train TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=tdmpc \
env=simxarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
device=cpu \
save_model=true \
save_freq=2 \
hydra.run.dir=tests/outputs/tdmpc/
- name: Test eval TDMPC on Simxarm end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config tests/outputs/tdmpc/.hydra/config.yaml \
eval_episodes=1 \
env.episode_length=8 \
device=cpu \
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
run: |
source .venv/bin/activate
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
env=simxarm \
eval_episodes=1 \
device=cpu

View File

@@ -14,11 +14,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.1
rev: v3.15.2
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.3.4
hooks:
- id: ruff
args: [--fix]

25
LICENSE
View File

@@ -253,6 +253,31 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Some of lerobot's code is derived from simxarm, which is subject to the following copyright notice:
MIT License
Copyright (c) 2023 Nicklas Hansen & Yanjie Ze
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
## Some of lerobot's code is derived from ALOHA, which is subject to the following copyright notice:
MIT License

483
README.md
View File

@@ -1,89 +1,360 @@
# Le Robot
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
</picture>
<br/>
<br/>
</p>
#### State-of-the-art machine learning for real-world robotics
<div align="center">
Le Robot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
[![Tests](https://github.com/huggingface/lerobot/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot)
[![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE)
[![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/)
[![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/)
[![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples)
[![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb)
Le Robot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
</div>
Le Robot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more supports for real-world robotics on the most affordable and capable robots out there.
<h3 align="center">
<p>State-of-the-art Machine Learning for real-world robotics</p>
</h3>
Le Robot is built upon [TorchRL](https://github.com/pytorch/rl) which provides abstractions and utilities for Reinforcement Learning.
---
## Acknowledgment
- Our ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
#### Examples of pretrained models and environments
<table>
<tr>
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
</tr>
<tr>
<td align="center">ACT policy on ALOHA env</td>
<td align="center">TDMPC policy on SimXArm env</td>
<td align="center">Diffusion policy on PushT env</td>
</tr>
</table>
### Acknowledgment
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
## Installation
Create a virtual environment with Python 3.10, e.g. using `conda`:
Download our source code:
```bash
git clone https://github.com/huggingface/lerobot.git
cd lerobot
```
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
```
curl -sSL https://install.python-poetry.org | python -
```
Install dependencies
```
poetry install
```
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
```
mkdir ~/tmp
export TMPDIR='~/tmp'
Then, install 🤗 LeRobot:
```bash
python -m pip install .
```
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
```
```bash
wandb login
```
## Usage
### Train
## Walkthrough
```
python lerobot/scripts/train.py \
hydra.job.name=pusht \
env=pusht
```
### Visualize offline buffer
.
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
| ├── common # contains classes and utilities
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
| | ├── envs # various sim environments: aloha, pusht, simxarm
| | └── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
| ├── eval.py # load policy and evaluate it on an environment
| └── train.py # train a policy via imitation learning and/or reinforcement learning
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
├── .github
| └── workflows
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
└── tests # contains pytest utilities for continuous integration
```
### Visualize datasets
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# we use this sampler to sample 1 frame after the other
sampler = SamplerWithoutReplacement(shuffle=False)
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_samples=300,
fps=50,
)
print(video_paths)
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
env=pusht
env=aloha \
task=sim_sim_transfer_cube_human \
hydra.run.dir=outputs/visualize_dataset/example
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
```
### Visualize online buffer / Eval
### Evaluate a pretrained policy
```
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/eval.py \
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
env=pusht
--hub-id lerobot/diffusion_policy_pusht_image \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub
```
After training your own policy, you can also re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval_episodes=10 \
hydra.run.dir=outputs/eval/example_dir
```
## TODO
See `python lerobot/scripts/eval.py --help` for more instructions.
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
### Train your own policy
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
dataset_id=aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
```
## Contribute
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
### TODO
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
### Follow our style
```bash
# install if needed
pre-commit install
# apply style and linter checks before git commit
pre-commit
```
### Add dependencies
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
Install the project with:
```bash
poetry install
```
Then, the equivalent of `pip install some-package`, would just be:
```bash
poetry add some-package
```
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash
# Add the new package to your main poetry env
poetry add some-package
# Add the same package to the CPU-only env dedicated to CI
conda create -y -n lerobot-ci python=3.10
conda activate lerobot-ci
cd .github/poetry/cpu
poetry add some-package
```
### Run tests locally
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```bash
brew install git-lfs
git lfs install
```
On Ubuntu:
```bash
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/data](tests/data)
```bash
git lfs pull
```
When adding a new dataset, mock it with
```bash
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```
Run tests
```bash
DATA_DIR="tests/data" pytest -sx tests
```
### Add a new dataset
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then you can upload it to the hub with:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.0
```
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.0",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
```bash
HF_USER=lerobot
DATASET=pusht
```
If you want to improve an existing dataset, you can download it locally with:
```bash
mkdir -p data/$DATASET
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
--repo-type dataset \
--local-dir data/$DATASET \
--local-dir-use-symlinks=False \
--revision v1.0
```
Iterate on your code and dataset with:
```bash
DATA_DIR=data python train.py
```
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.1 \
--delete "*"
```
Then you will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
## Profile
Finally, you might want to mock the dataset if you need to update the unit tests as well:
```bash
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```
**Example**
### Add a pretrained policy
Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
Secondly, assuming you have trained a policy, you need:
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
```
With the folder prepared, run the following with a desired revision ID.
```bash
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
```
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
```bash
huggingface-cli upload $HUB_ID to_upload
```
See `eval.py` for an example of how a user may use your policy.
### Improve your code with profiling
An example of a code snippet to profile the evaluation of a policy:
```python
from torch.profiler import profile, record_function, ProfilerActivity
@@ -102,124 +373,12 @@ with profile(
with record_function("eval_policy"):
for i in range(num_episodes):
prof.step()
# insert code to profile, potentially whole body of eval_policy function
```
```bash
python lerobot/scripts/eval.py \
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
--config outputs/pusht/.hydra/config.yaml \
pretrained_model_path=outputs/pusht/model.pt \
eval_episodes=7
```
## Contribute
**Style**
```
# install if needed
pre-commit install
# apply style and linter checks before git commit
pre-commit run -a
```
**Adding dependencies (temporary)**
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
```
# Run in this directory, adds the package to the main env with cuda
poetry add some-package
# Adds the same package to the cpu env
cd .github/poetry/cpu && poetry add some-package
```
**Tests**
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
On Mac:
```
brew install git-lfs
git lfs install
```
On Ubuntu:
```
sudo apt-get install git-lfs
git lfs install
```
Pull artifacts if they're not in [tests/data](tests/data)
```
git lfs pull
```
When adding a new dataset, mock it with
```
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```
Run tests
```
DATA_DIR="tests/data" pytest -sx tests
```
**Datasets**
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Then you can upload it to the hub with:
```
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.0
```
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.0",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
```
HF_USER=cadene
DATASET=pusht
```
If you want to improve an existing dataset, you can download it locally with:
```
mkdir -p data/$DATASET
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
--repo-type dataset \
--local-dir data/$DATASET \
--local-dir-use-symlinks=False \
--revision v1.0
```
Iterate on your code and dataset with:
```
DATA_DIR=data python train.py
```
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
```
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
--repo-type dataset \
--revision v1.1 \
--delete "*"
```
Then you will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
Finally, you might want to mock the dataset if you need to update the unit tests as well:
```
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
```

View File

@@ -0,0 +1,24 @@
import os
from torchrl.data.replay_buffers import SamplerWithoutReplacement
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# we use this sampler to sample 1 frame after the other
sampler = SamplerWithoutReplacement(shuffle=False)
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_samples=300,
fps=50,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']

View File

@@ -0,0 +1,39 @@
"""
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
"""
from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# folder = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"device=cuda",
]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos.
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@@ -0,0 +1,55 @@
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from tqdm import trange
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)
overrides = [
"env=pusht",
"policy=diffusion",
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
"offline_steps=5000",
"log_freq=250",
"device=cuda",
]
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
policy.train()
offline_buffer = make_offline_buffer(cfg)
for offline_step in trange(cfg.offline_steps):
train_info = policy.update(offline_buffer, offline_step)
if offline_step % cfg.log_freq == 0:
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save_pretrained(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")

View File

@@ -1 +1,59 @@
"""
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
Example:
```python
import lerobot
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets_per_env)
print(lerobot.available_datasets)
print(lerobot.available_policies)
```
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
"simxarm",
]
available_tasks_per_env = {
"aloha": [
"sim_insertion",
"sim_transfer_cube",
],
"pusht": ["pusht"],
"simxarm": ["lift"],
}
available_datasets_per_env = {
"aloha": [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"simxarm": ["xarm_lift_medium"],
}
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
available_policies = [
"act",
"diffusion",
"tdmpc",
]

View File

@@ -1,4 +1,4 @@
""" To enable `lerobot.__version__` """
"""To enable `lerobot.__version__`"""
from importlib.metadata import PackageNotFoundError, version

View File

@@ -9,32 +9,68 @@ import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
HF_USER = "lerobot"
class AbstractDataset(TensorDictReplayBuffer):
"""
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
These implementations can vary based on the source of the data, the environment the data pertains to,
or the specific kind of data manipulation applied.
Note:
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
functionality for storing and retrieving `TensorDict`-like data.
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
It is expected that these variants correspond to a HuggingFace dataset on the hub.
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
available_datasets: list[str] | None = None
class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
assert (
self.available_datasets is not None
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
assert (
dataset_id in self.available_datasets
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
self.dataset_id = dataset_id
self.version = version
self.shuffle = shuffle
self.root = root
self.root = root if root is None else Path(root)
if self.root is not None and self.version is not None:
logging.warning(
@@ -106,7 +142,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
if self.root is None:
self.data_dir = Path(
snapshot_download(
repo_id=f"cadene/{self.dataset_id}", repo_type="dataset", revision=self.version
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
)
)
else:

View File

@@ -9,11 +9,11 @@ import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.abstract import AbstractDataset
DATASET_IDS = [
"aloha_sim_insertion_human",
@@ -80,24 +80,24 @@ def download(data_dir, dataset_id):
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
class AlohaDataset(AbstractDataset):
available_datasets = DATASET_IDS
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
batch_size: int = None,
version: str | None = "v1.2",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
assert dataset_id in DATASET_IDS
super().__init__(
dataset_id,
version,

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
@@ -14,7 +14,13 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer(
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
cfg,
overwrite_sampler=None,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
overwrite_batch_size=None,
overwrite_prefetch=None,
stats_path=None,
):
if cfg.policy.balanced_sampling:
assert cfg.online_steps > 0
@@ -59,27 +65,24 @@ def make_offline_buffer(
sampler = overwrite_sampler
if cfg.env.name == "simxarm":
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmDataset
clsfunc = SimxarmExperienceReplay
dataset_id = f"xarm_{cfg.env.task}_medium"
clsfunc = SimxarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtExperienceReplay
dataset_id = "pusht"
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaExperienceReplay
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaExperienceReplay
dataset_id = f"aloha_{cfg.env.task}"
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
offline_buffer = clsfunc(
dataset_id=dataset_id,
dataset_id=cfg.dataset_id,
sampler=sampler,
batch_size=batch_size,
root=DATA_DIR,
@@ -95,13 +98,15 @@ def make_offline_buffer(
else:
img_keys = offline_buffer.image_keys
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
# we only normalize the state and action, since the images are usually normalized inside the model for
# now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":
@@ -122,7 +127,7 @@ def make_offline_buffer(
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
offline_buffer.set_transform(transforms)
offline_buffer.set_transform(transforms)
if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)

View File

@@ -9,11 +9,11 @@ import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.abstract import AbstractDataset
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
@@ -83,20 +83,22 @@ def add_tee(
return body
class PushtExperienceReplay(AbstractExperienceReplay):
class PushtDataset(AbstractDataset):
available_datasets = ["pusht"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
batch_size: int = None,
version: str | None = "v1.2",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
super().__init__(

View File

@@ -8,12 +8,12 @@ import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import (
SliceSampler,
Sampler,
)
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
from lerobot.common.datasets.abstract import AbstractDataset
def download():
@@ -32,7 +32,7 @@ def download():
Path(download_path).unlink()
class SimxarmExperienceReplay(AbstractExperienceReplay):
class SimxarmDataset(AbstractDataset):
available_datasets = [
"xarm_lift_medium",
]
@@ -40,16 +40,16 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
def __init__(
self,
dataset_id: str,
version: str | None = None,
batch_size: int = None,
version: str | None = "v1.1",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
):
super().__init__(
@@ -67,11 +67,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
)
def _download_and_preproc_obsolete(self):
assert self.root is not None
# assert self.root is not None
# TODO(rcadene): finish download
download()
# download()
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
@@ -105,15 +105,19 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "observation", "reward"): next_reward,
("next", "observation", "done"): next_done,
("next", "reward"): next_reward,
("next", "done"): next_done,
},
batch_size=num_frames,
)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
)
td_data[idx0:idx1] = episode

View File

@@ -4,8 +4,24 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from lerobot.common.utils import set_global_seed
class AbstractEnv(EnvBase):
"""
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
name: str | None = None # same name should be used to instantiate the environment in factory.py
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
def __init__(
self,
task,
@@ -19,6 +35,14 @@ class AbstractEnv(EnvBase):
num_prev_action=0,
):
super().__init__(device=device, batch_size=[])
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
assert (
self.available_tasks is not None
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
assert (
task in self.available_tasks
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
@@ -34,7 +58,13 @@ class AbstractEnv(EnvBase):
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
# you store the return value in self._next_seed (it will be a new randomly generated seed).
self._next_seed = seed
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
# self._reset is called, we use seed.
self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
@@ -59,4 +89,4 @@ class AbstractEnv(EnvBase):
raise NotImplementedError("Abstract method")
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError("Abstract method")
set_global_seed(seed)

View File

@@ -29,12 +29,14 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_gym = importlib.util.find_spec("gymnasium") is not None
class AlohaEnv(AbstractEnv):
name = "aloha"
available_tasks = ["sim_insertion", "sim_transfer_cube"]
_reset_warning_issued = False
def __init__(
@@ -63,7 +65,7 @@ class AlohaEnv(AbstractEnv):
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
if not self.from_pixels:
raise NotImplementedError()
@@ -126,9 +128,8 @@ class AlohaEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
AlohaEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
@@ -137,8 +138,6 @@ class AlohaEnv(AbstractEnv):
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs.observation)
@@ -207,7 +206,7 @@ class AlohaEnv(AbstractEnv):
if self.from_pixels:
if isinstance(self.image_size, int):
image_shape = (3, self.image_size, self.image_size)
elif OmegaConf.is_list(self.image_size):
elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
assert len(self.image_size) == 3 # c h w
assert self.image_size[0] == 3 # c is RGB
image_shape = tuple(self.image_size)
@@ -293,7 +292,7 @@ class AlohaEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
set_global_seed(seed)
# TODO(rcadene): seed the env
# self._env.seed(seed)
logging.warning("Aloha env is not seeded")

View File

@@ -17,7 +17,7 @@ def make_env(cfg, transform=None):
}
if cfg.env.name == "simxarm":
from lerobot.common.envs.simxarm import SimxarmEnv
from lerobot.common.envs.simxarm.env import SimxarmEnv
kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv

View File

@@ -16,12 +16,14 @@ from torchrl.data.tensor_specs import (
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
_has_gym = importlib.util.find_spec("gym") is not None
_has_gym = importlib.util.find_spec("gymnasium") is not None
class PushtEnv(AbstractEnv):
name = "pusht"
available_tasks = ["pusht"]
_reset_warning_issued = False
def __init__(
@@ -50,7 +52,7 @@ class PushtEnv(AbstractEnv):
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
@@ -106,11 +108,9 @@ class PushtEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
PushtEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs)
@@ -239,5 +239,7 @@ class PushtEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
# Set global seed.
set_global_seed(seed)
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
self._env.seed(seed)

View File

@@ -1,14 +1,14 @@
import collections
import cv2
import gym
import gymnasium as gym
import numpy as np
import pygame
import pymunk
import pymunk.pygame_util
import shapely.geometry as sg
import skimage.transform as st
from gym import spaces
from gymnasium import spaces
from pymunk.vec2d import Vec2d
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
@@ -33,7 +33,7 @@ class PushTEnv(gym.Env):
def __init__(
self,
legacy=False,
legacy=True, # compatibility with original
block_cog=None,
damping=None,
render_action=True,

View File

@@ -1,5 +1,5 @@
import numpy as np
from gym import spaces
from gymnasium import spaces
from lerobot.common.envs.pusht.pusht_env import PushTEnv
@@ -7,7 +7,8 @@ from lerobot.common.envs.pusht.pusht_env import PushTEnv
class PushTImageEnv(PushTEnv):
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
# Note: legacy defaults to True for compatibility with original
def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
super().__init__(
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
)

View File

@@ -1,4 +1,5 @@
import importlib
import logging
from collections import deque
from typing import Optional
@@ -15,15 +16,17 @@ from torchrl.data.tensor_specs import (
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
MAX_NUM_ACTIONS = 4
_has_gym = importlib.util.find_spec("gym") is not None
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
_has_gym = importlib.util.find_spec("gymnasium") is not None
class SimxarmEnv(AbstractEnv):
name = "simxarm"
available_tasks = ["lift"]
def __init__(
self,
task,
@@ -49,13 +52,12 @@ class SimxarmEnv(AbstractEnv):
)
def _make_env(self):
if not _has_simxarm:
raise ImportError("Cannot import simxarm.")
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
import gym
from simxarm import TASKS
import gymnasium
from lerobot.common.envs.simxarm.simxarm import TASKS
if self.task not in TASKS:
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
@@ -63,7 +65,7 @@ class SimxarmEnv(AbstractEnv):
self._env = TASKS[self.task]["env"]()
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
@@ -84,7 +86,7 @@ class SimxarmEnv(AbstractEnv):
else:
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
obs = TensorDict(obs, batch_size=[])
# obs = TensorDict(obs, batch_size=[])
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
@@ -229,5 +231,7 @@ class SimxarmEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
self._env.seed(seed)
set_global_seed(seed)
self._seed = seed
# TODO(aliberts): change self._reset so that it takes in a seed value
logging.warning("simxarm env is not properly seeded")

View File

@@ -0,0 +1,166 @@
from collections import OrderedDict, deque
import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TimeLimit
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
TASKS = OrderedDict(
(
(
"reach",
{
"env": Reach,
"action_space": "xyz",
"episode_length": 50,
"description": "Reach a target location with the end effector",
},
),
(
"push",
{
"env": Push,
"action_space": "xyz",
"episode_length": 50,
"description": "Push a cube to a target location",
},
),
(
"peg_in_box",
{
"env": PegInBox,
"action_space": "xyz",
"episode_length": 50,
"description": "Insert a peg into a box",
},
),
(
"lift",
{
"env": Lift,
"action_space": "xyzw",
"episode_length": 50,
"description": "Lift a cube above a height threshold",
},
),
)
)
class SimXarmWrapper(gym.Wrapper):
"""
A wrapper for the SimXarm environments. This wrapper is used to
convert the action and observation spaces to the correct format.
"""
def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
super().__init__(env)
self._env = env
self.obs_mode = obs_mode
self.image_size = image_size
self.action_repeat = action_repeat
self.frame_stack = frame_stack
self._frames = deque([], maxlen=frame_stack)
self.channel_last = channel_last
self._max_episode_steps = task["episode_length"] // action_repeat
image_shape = (
(image_size, image_size, 3 * frame_stack)
if channel_last
else (3 * frame_stack, image_size, image_size)
)
if obs_mode == "state":
self.observation_space = env.observation_space["observation"]
elif obs_mode == "rgb":
self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
elif obs_mode == "all":
self.observation_space = gym.spaces.Dict(
state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
)
else:
raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
if "w" not in task["action_space"]:
self.action_padding[-1] = 1.0
def _render_obs(self):
obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
if not self.channel_last:
obs = obs.transpose(2, 0, 1)
return obs.copy()
def _update_frames(self, reset=False):
pixels = self._render_obs()
self._frames.append(pixels)
if reset:
for _ in range(1, self.frame_stack):
self._frames.append(pixels)
assert len(self._frames) == self.frame_stack
def transform_obs(self, obs, reset=False):
if self.obs_mode == "state":
return obs["observation"]
elif self.obs_mode == "rgb":
self._update_frames(reset=reset)
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
return rgb_obs
elif self.obs_mode == "all":
self._update_frames(reset=reset)
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
else:
raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")
def reset(self):
return self.transform_obs(self._env.reset(), reset=True)
def step(self, action):
action = np.concatenate([action, self.action_padding])
reward = 0.0
for _ in range(self.action_repeat):
obs, r, done, info = self._env.step(action)
reward += r
return self.transform_obs(obs), reward, done, info
def render(self, mode="rgb_array", width=384, height=384, **kwargs):
return self._env.render(mode, width=width, height=height)
@property
def state(self):
return self._env.robot_state
def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
"""
Create a new environment.
Args:
task (str): The task to create an environment for. Must be one of:
- 'reach'
- 'push'
- 'peg-in-box'
- 'lift'
obs_mode (str): The observation mode to use. Must be one of:
- 'state': Only state observations
- 'rgb': RGB images
- 'all': RGB images and state observations
image_size (int): The size of the image observations
action_repeat (int): The number of times to repeat the action
seed (int): The random seed to use
Returns:
gym.Env: The environment
"""
if task not in TASKS:
raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
env = TASKS[task]["env"]()
env = TimeLimit(env, TASKS[task]["episode_length"])
env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
env.seed(seed)
return env

View File

@@ -0,0 +1,53 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco>
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
<size nconmax="2000" njmax="500"/>
<option timestep="0.002">
<flag warmstart="enable"></flag>
</option>
<include file="shared.xml"></include>
<worldbody>
<body name="floor0" pos="0 0 0">
<geom name="floorgeom0" pos="1.2 -2.0 0" size="20.0 20.0 1" type="plane" condim="3" material="floor_mat"></geom>
</body>
<include file="xarm.xml"></include>
<body pos="0.75 0 0.6325" name="pedestal0">
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
</body>
<body pos="1.5 0.075 0.3425" name="table0">
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 1 1"></geom>
</body>
<body name="object" pos="1.405 0.3 0.58625">
<joint name="object_joint0" type="free" limited="false"></joint>
<geom size="0.035 0.035 0.035" type="box" name="object0" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
<site name="object_site" pos="0 0 0" size="0.035 0.035 0.035" rgba="1 0 0 0" type="box"></site>
</body>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
</worldbody>
<equality>
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
</equality>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736
size 1211434

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66
size 3284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de
size 3284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86
size 6284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4
size 242684

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac
size 366284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921
size 22284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636
size 184184

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455
size 225584

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42
size 237084

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089
size 243684

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90
size 229084

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb
size 399384

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18
size 231684

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0
size 2106184

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d
size 242684

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e
size 366284

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8
size 22284

View File

@@ -0,0 +1,74 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco>
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
<size nconmax="2000" njmax="500"/>
<option timestep="0.001">
<flag warmstart="enable"></flag>
</option>
<include file="shared.xml"></include>
<worldbody>
<body name="floor0" pos="0 0 0">
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
</body>
<include file="xarm.xml"></include>
<body pos="0.75 0 0.6325" name="pedestal0">
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
</body>
<body pos="1.5 0.075 0.3425" name="table0">
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
</body>
<body name="box0" pos="1.605 0.25 0.55">
<joint name="box_joint0" type="free" limited="false"></joint>
<site name="box_site" pos="0 0.075 -0.01" size="0.02" rgba="0 0 0 0" type="sphere"></site>
<geom name="box_side0" pos="0 0 0" size="0.065 0.002 0.04" type= "box" rgba="0.8 0.1 0.1 1" mass ="1" condim="4" />
<geom name="box_side1" pos="0 0.149 0" size="0.065 0.002 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
<geom name="box_side2" pos="0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.8 0.1 0.1 1" mass ="2" condim="4" />
<geom name="box_side3" pos="-0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
<geom name="box_side4" pos="-0 0.074 -0.038" size="0.065 0.075 0.002" type="box" rgba="0.5 0 0 1" mass ="2" condim="4"/>
</body>
<body name="object0" pos="1.4 0.25 0.65">
<joint name="object_joint0" type="free" limited="false"></joint>
<geom name="object_target0" type="cylinder" pos="0 0 -0.05" size="0.03 0.035" rgba="0.6 0.8 0.5 1" mass ="0.1" condim="3" />
<site name="object_site" pos="0 0 -0.05" size="0.0325 0.0375" rgba="0 0 0 0" type="cylinder"></site>
<body name="B0" pos="0 0 0" euler="0 0 0 ">
<joint name="B0:joint" type="slide" limited="true" axis="0 0 1" damping="0.05" range="0.0001 0.0001001" solimpfriction="0.98 0.98 0.95" frictionloss="1"></joint>
<geom type="capsule" size="0.002 0.03" rgba="0 0 0 1" mass="0.001" condim="4"/>
<body name="B1" pos="0 0 0.04" euler="0 3.14 0 ">
<joint name="B1:joint1" type="hinge" axis="1 0 0" range="-0.1 0.1" frictionloss="1"></joint>
<joint name="B1:joint2" type="hinge" axis="0 1 0" range="-0.1 0.1" frictionloss="1"></joint>
<joint name="B1:joint3" type="hinge" axis="0 0 1" range="-0.1 0.1" frictionloss="1"></joint>
<geom type="capsule" size="0.002 0.004" rgba="1 0 0 0" mass="0.001" condim="4"/>
</body>
</body>
</body>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
</worldbody>
<equality>
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<weld body1="right_hand" body2="B1" solimp="0.99 0.99 0.99" solref="0.02 1"></weld>
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
</equality>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,54 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco>
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
<size nconmax="2000" njmax="500"/>
<option timestep="0.002">
<flag warmstart="enable"></flag>
</option>
<include file="shared.xml"></include>
<worldbody>
<body name="floor0" pos="0 0 0">
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
<site name="target0" pos="1.565 0.3 0.545" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
</body>
<include file="xarm.xml"></include>
<body pos="0.75 0 0.6325" name="pedestal0">
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
</body>
<body pos="1.5 0.075 0.3425" name="table0">
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
</body>
<body name="object" pos="1.655 0.3 0.68">
<joint name="object_joint0" type="free" limited="false"></joint>
<geom size="0.024 0.024 0.024" type="box" name="object" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
<site name="object_site" pos="0 0 0" size="0.024 0.024 0.024" rgba="0 0 0 0" type="box"></site>
</body>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
</worldbody>
<equality>
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
</equality>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,48 @@
<?xml version="1.0" encoding="utf-8"?>
<mujoco>
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
<size nconmax="2000" njmax="500"/>
<option timestep="0.002">
<flag warmstart="enable"></flag>
</option>
<include file="shared.xml"></include>
<worldbody>
<body name="floor0" pos="0 0 0">
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
<site name="target0" pos="1.605 0.3 0.58" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
</body>
<include file="xarm.xml"></include>
<body pos="0.75 0 0.6325" name="pedestal0">
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
</body>
<body pos="1.5 0.075 0.3425" name="table0">
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
</body>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
</worldbody>
<equality>
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
</equality>
<actuator>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
</actuator>
</mujoco>

View File

@@ -0,0 +1,51 @@
<mujoco>
<asset>
<texture type="skybox" builtin="gradient" rgb1="0.0 0.0 0.0" rgb2="0.0 0.0 0.0" width="32" height="32"></texture>
<material name="floor_mat" specular="0" shininess="0.0" reflectance="0" rgba="0.043 0.055 0.051 1"></material>
<material name="table_mat" specular="0.2" shininess="0.2" reflectance="0" rgba="1 1 1 1"></material>
<material name="pedestal_mat" specular="0.35" shininess="0.5" reflectance="0" rgba="0.705 0.585 0.405 1"></material>
<material name="block_mat" specular="0.5" shininess="0.9" reflectance="0.05" rgba="0.373 0.678 0.627 1"></material>
<material name="robot0:geomMat" shininess="0.03" specular="0.4"></material>
<material name="robot0:gripper_finger_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="robot0:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="background:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="robot0:arm_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="robot0:head_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="robot0:torso_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<material name="robot0:base_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
<mesh name="link_base" file="link_base.stl" />
<mesh name="link1" file="link1.stl" />
<mesh name="link2" file="link2.stl" />
<mesh name="link3" file="link3.stl" />
<mesh name="link4" file="link4.stl" />
<mesh name="link5" file="link5.stl" />
<mesh name="link6" file="link6.stl" />
<mesh name="link7" file="link7.stl" />
<mesh name="base_link" file="base_link.stl" />
<mesh name="left_outer_knuckle" file="left_outer_knuckle.stl" />
<mesh name="left_finger" file="left_finger.stl" />
<mesh name="left_inner_knuckle" file="left_inner_knuckle.stl" />
<mesh name="right_outer_knuckle" file="right_outer_knuckle.stl" />
<mesh name="right_finger" file="right_finger.stl" />
<mesh name="right_inner_knuckle" file="right_inner_knuckle.stl" />
</asset>
<equality>
<weld body1="robot0:mocap2" body2="link7" solimp="0.9 0.95 0.001" solref="0.02 1"></weld>
</equality>
<default>
<joint armature="1" damping="0.1" limited="true"/>
<default class="robot0:blue">
<geom rgba="0.086 0.506 0.767 1.0"></geom>
</default>
<default class="robot0:grey">
<geom rgba="0.356 0.361 0.376 1.0"></geom>
</default>
</default>
</mujoco>

View File

@@ -0,0 +1,88 @@
<mujoco model="xarm7">
<body mocap="true" name="robot0:mocap2" pos="0 0 0">
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0.5 0 0" size="0.005 0.005 0.005" type="box"></geom>
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0 0 0" size="1 0.005 0.005" type="box"></geom>
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0 0.5 0" size="0.005 1 0.001" type="box"></geom>
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0.5 0 0" size="0.005 0.005 1" type="box"></geom>
</body>
<body name="link0" pos="1.09 0.28 0.655">
<geom name="bb" type="mesh" mesh="link_base" material="robot0:base_mat" rgba="1 1 1 1"/>
<body name="link1" pos="0 0 0.267">
<inertial pos="-0.0042142 0.02821 -0.0087788" quat="0.917781 -0.277115 0.0606681 0.277858" mass="0.42603" diaginertia="0.00144551 0.00137757 0.000823511" />
<joint name="joint1" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="10" frictionloss="1" />
<geom name="j1" type="mesh" mesh="link1" material="robot0:arm_mat" rgba="1 1 1 1"/>
<body name="link2" pos="0 0 0" quat="0.707105 -0.707108 0 0">
<inertial pos="-3.3178e-05 -0.12849 0.026337" quat="0.447793 0.894132 -0.00224061 0.00218314" mass="0.56095" diaginertia="0.00319151 0.00311598 0.000980804" />
<joint name="joint2" pos="0 0 0" axis="0 0 1" limited="true" range="-2.059 2.0944" damping="10" frictionloss="1" />
<geom name="j2" type="mesh" mesh="link2" material="robot0:head_mat" rgba="1 1 1 1"/>
<body name="link3" pos="0 -0.293 0" quat="0.707105 0.707108 0 0">
<inertial pos="0.04223 -0.023258 -0.0096674" quat="0.883205 0.339803 0.323238 0.000542237" mass="0.44463" diaginertia="0.00133227 0.00119126 0.000780475" />
<joint name="joint3" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
<geom name="j3" type="mesh" mesh="link3" material="robot0:gripper_mat" rgba="1 1 1 1"/>
<body name="link4" pos="0.0525 0 0" quat="0.707105 0.707108 0 0">
<inertial pos="0.067148 -0.10732 0.024479" quat="0.0654142 0.483317 -0.738663 0.465298" mass="0.52387" diaginertia="0.00288984 0.00282705 0.000894409" />
<joint name="joint4" pos="0 0 0" axis="0 0 1" limited="true" range="-0.19198 3.927" damping="5" frictionloss="1" />
<geom name="j4" type="mesh" mesh="link4" material="robot0:arm_mat" rgba="1 1 1 1"/>
<body name="link5" pos="0.0775 -0.3425 0" quat="0.707105 0.707108 0 0">
<inertial pos="-0.00023397 0.036705 -0.080064" quat="0.981064 -0.19003 0.00637998 0.0369004" mass="0.18554" diaginertia="0.00099553 0.000988613 0.000247126" />
<joint name="joint5" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
<geom name="j5" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link5" />
<body name="link6" pos="0 0 0" quat="0.707105 0.707108 0 0">
<inertial pos="0.058911 0.028469 0.0068428" quat="-0.188705 0.793535 0.166088 0.554173" mass="0.31344" diaginertia="0.000827892 0.000768871 0.000386708" />
<joint name="joint6" pos="0 0 0" axis="0 0 1" limited="true" range="-1.69297 3.14159" damping="2" frictionloss="1" />
<geom name="j6" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link6" />
<body name="link7" pos="0.076 0.097 0" quat="0.707105 -0.707108 0 0">
<inertial pos="-0.000420033 -0.00287433 0.0257078" quat="0.999372 -0.0349129 -0.00605634 0.000551744" mass="0.85624" diaginertia="0.00137671 0.00118744 0.000514968" />
<joint name="joint7" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="2" frictionloss="1" />
<geom name="j8" material="robot0:gripper_mat" type="mesh" rgba="0.753 0.753 0.753 1" mesh="link7" />
<geom name="j9" material="robot0:gripper_mat" type="mesh" rgba="1 1 1 1" mesh="base_link" />
<site name="grasp" pos="0 0 0.16" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
<body name="left_outer_knuckle" pos="0 0.035 0.059098">
<inertial pos="0 0.021559 0.015181" quat="0.47789 0.87842 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
<joint name="drive_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_outer_knuckle" />
<body name="left_finger" pos="0 0.035465 0.042039">
<inertial pos="0 -0.016413 0.029258" quat="0.697634 0.115353 -0.115353 0.697634" mass="0.048304" diaginertia="1.88037e-05 1.7493e-05 3.56792e-06" />
<joint name="left_finger_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
<geom name="j10" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="left_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
<body name="right_hand" pos="0 -0.03 0.05" quat="-0.7071 0 0 0.7071">
<site name="ee" pos="0 0 0" rgba="0 0 1 0" type="sphere" group="1"/>
<site name="ee_x" pos="0 0 0" size="0.005 .1" quat="0.707105 0.707108 0 0 " rgba="1 0 0 0" type="cylinder" group="1"/>
<site name="ee_z" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0 0.707108" rgba="0 0 1 0" type="cylinder" group="1"/>
<site name="ee_y" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0.707108 0 " rgba="0 1 0 0" type="cylinder" group="1"/>
</body>
</body>
</body>
<body name="left_inner_knuckle" pos="0 0.02 0.074098">
<inertial pos="1.86601e-06 0.0220468 0.0261335" quat="0.664139 -0.242732 0.242713 0.664146" mass="0.0230126" diaginertia="8.34216e-06 6.0949e-06 2.75601e-06" />
<joint name="left_inner_knuckle_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
</body>
<body name="right_outer_knuckle" pos="0 -0.035 0.059098">
<inertial pos="0 -0.021559 0.015181" quat="0.87842 0.47789 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
<joint name="right_outer_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_outer_knuckle" />
<body name="right_finger" pos="0 -0.035465 0.042039">
<inertial pos="0 0.016413 0.029258" quat="0.697634 -0.115356 0.115356 0.697634" mass="0.048304" diaginertia="1.88038e-05 1.7493e-05 3.56779e-06" />
<joint name="right_finger_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
<geom name="j11" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="right_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
<body name="left_hand" pos="0 0.03 0.05" quat="-0.7071 0 0 0.7071">
<site name="ee_2" pos="0 0 0" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
</body>
</body>
</body>
<body name="right_inner_knuckle" pos="0 -0.02 0.074098">
<inertial pos="1.866e-06 -0.022047 0.026133" quat="0.66415 0.242702 -0.242721 0.664144" mass="0.023013" diaginertia="8.34209e-06 6.0949e-06 2.75601e-06" />
<joint name="right_inner_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</body>
</mujoco>

View File

@@ -0,0 +1,145 @@
import os
import mujoco
import numpy as np
from gymnasium_robotics.envs import robot_env
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
class Base(robot_env.MujocoRobotEnv):
"""
Superclass for all simxarm environments.
Args:
xml_name (str): name of the xml environment file
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
"""
def __init__(self, xml_name, gripper_rotation=None):
if gripper_rotation is None:
gripper_rotation = [0, 1, 0, 0]
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
self.center_of_table = np.array([1.655, 0.3, 0.63625])
self.max_z = 1.2
self.min_z = 0.2
super().__init__(
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
n_substeps=20,
n_actions=4,
initial_qpos={},
)
@property
def dt(self):
return self.n_substeps * self.model.opt.timestep
@property
def eef(self):
return self._utils.get_site_xpos(self.model, self.data, "grasp")
@property
def obj(self):
return self._utils.get_site_xpos(self.model, self.data, "object_site")
@property
def robot_state(self):
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
return np.concatenate([self.eef, gripper_angle])
def is_success(self):
return NotImplementedError()
def get_reward(self):
raise NotImplementedError()
def _sample_goal(self):
raise NotImplementedError()
def get_obs(self):
return self._get_obs()
def _step_callback(self):
self._mujoco.mj_forward(self.model, self.data)
def _limit_gripper(self, gripper_pos, pos_ctrl):
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
pos_ctrl[0] = min(pos_ctrl[0], 0)
if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
pos_ctrl[0] = max(pos_ctrl[0], 0)
if gripper_pos[1] > self.center_of_table[1] + 0.3:
pos_ctrl[1] = min(pos_ctrl[1], 0)
if gripper_pos[1] < self.center_of_table[1] - 0.3:
pos_ctrl[1] = max(pos_ctrl[1], 0)
if gripper_pos[2] > self.max_z:
pos_ctrl[2] = min(pos_ctrl[2], 0)
if gripper_pos[2] < self.min_z:
pos_ctrl[2] = max(pos_ctrl[2], 0)
return pos_ctrl
def _apply_action(self, action):
assert action.shape == (4,)
action = action.copy()
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl = self._limit_gripper(
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
) * (1 / self.n_substeps)
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
mocap.apply_action(
self.model,
self._model_names,
self.data,
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
)
def _render_callback(self):
self._mujoco.mj_forward(self.model, self.data)
def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
self._sample_goal()
self._mujoco.mj_step(self.model, self.data, nstep=10)
return True
def _set_gripper(self, gripper_pos, gripper_rotation):
self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
self.data.qpos[10] = 0.0
self.data.qpos[12] = 0.0
def _env_setup(self, initial_qpos):
for name, value in initial_qpos.items():
self.data.set_joint_qpos(name, value)
mocap.reset(self.model, self.data)
mujoco.mj_forward(self.model, self.data)
self._sample_goal()
mujoco.mj_forward(self.model, self.data)
def reset(self):
self._reset_sim()
return self._get_obs()
def step(self, action):
assert action.shape == (4,)
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
self._apply_action(action)
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback()
obs = self._get_obs()
reward = self.get_reward()
done = False
info = {"is_success": self.is_success(), "success": self.is_success()}
return obs, reward, done, info
def render(self, mode="rgb_array", width=384, height=384):
self._render_callback()
# HACK
self.model.vis.global_.offwidth = width
self.model.vis.global_.offheight = height
return self.mujoco_renderer.render(mode)
def close(self):
if self.mujoco_renderer is not None:
self.mujoco_renderer.close()

View File

@@ -0,0 +1,100 @@
import numpy as np
from lerobot.common.envs.simxarm.simxarm import Base
class Lift(Base):
def __init__(self):
self._z_threshold = 0.15
super().__init__("lift")
@property
def z_target(self):
return self._init_z + self._z_threshold
def is_success(self):
return self.obj[2] >= self.z_target
def get_reward(self):
reach_dist = np.linalg.norm(self.obj - self.eef)
reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1])
pick_completed = self.obj[2] >= (self.z_target - 0.01)
obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02)
# Reach
if reach_dist < 0.05:
reach_reward = -reach_dist + max(self._action[-1], 0) / 50
elif reach_dist_xy < 0.05:
reach_reward = -reach_dist
else:
z_bonus = np.linalg.norm(np.linalg.norm(self.obj[-1] - self.eef[-1]))
reach_reward = -reach_dist - 2 * z_bonus
# Pick
if pick_completed and not obj_dropped:
pick_reward = self.z_target
elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)):
pick_reward = min(self.z_target, self.obj[2])
else:
pick_reward = 0
return reach_reward / 100 + pick_reward
def _get_obs(self):
eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
eef = self.eef - self.center_of_table
obj = self.obj - self.center_of_table
obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]
obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt
obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt
obs = np.concatenate(
[
eef,
eef_velp,
obj,
obj_rot,
obj_velp,
obj_velr,
eef - obj,
np.array(
[
np.linalg.norm(eef - obj),
np.linalg.norm(eef[:-1] - obj[:-1]),
self.z_target,
self.z_target - obj[-1],
self.z_target - eef[-1],
]
),
gripper_angle,
],
axis=0,
)
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef}
def _sample_goal(self):
# Gripper
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
super()._set_gripper(gripper_pos, self.gripper_rotation)
# Object
object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07])
object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1)
object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1)
object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")
object_qpos[:3] = object_pos
self._utils.set_joint_qpos(self.model, self.data, "object_joint0", object_qpos)
self._init_z = object_pos[2]
# Goal
return object_pos + np.array([0, 0, self._z_threshold])
def reset(self):
self._action = np.zeros(4)
return super().reset()
def step(self, action):
self._action = action.copy()
return super().step(action)

View File

@@ -0,0 +1,67 @@
# import mujoco_py
import mujoco
import numpy as np
def apply_action(model, model_names, data, action):
if model.nmocap > 0:
pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
if data.ctrl is not None:
for i in range(gripper_action.shape[0]):
data.ctrl[i] = gripper_action[i]
pos_action = pos_action.reshape(model.nmocap, 7)
pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
reset_mocap2body_xpos(model, model_names, data)
data.mocap_pos[:] = data.mocap_pos + pos_delta
data.mocap_quat[:] = data.mocap_quat + quat_delta
def reset(model, data):
if model.nmocap > 0 and model.eq_data is not None:
for i in range(model.eq_data.shape[0]):
# if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD:
if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD:
# model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.])
model.eq_data[i, :] = np.array(
[
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
]
)
# sim.forward()
mujoco.mj_forward(model, data)
def reset_mocap2body_xpos(model, model_names, data):
if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
return
# For all weld constraints
for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
# if eq_type != mujoco_py.const.EQ_WELD:
if eq_type != mujoco.mjtEq.mjEQ_WELD:
continue
# body2 = model.body_id2name(obj2_id)
body2 = model_names.body_id2name[obj2_id]
if body2 == "B0" or body2 == "B9" or body2 == "B1":
continue
mocap_id = model.body_mocapid[obj1_id]
if mocap_id != -1:
# obj1 is the mocap, obj2 is the welded body
body_idx = obj2_id
else:
# obj2 is the mocap, obj1 is the welded body
mocap_id = model.body_mocapid[obj2_id]
body_idx = obj1_id
assert mocap_id != -1
data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
data.mocap_quat[mocap_id][:] = data.xquat[body_idx]

View File

@@ -0,0 +1,86 @@
import numpy as np
from lerobot.common.envs.simxarm.simxarm import Base
class PegInBox(Base):
def __init__(self):
super().__init__("peg_in_box")
def _reset_sim(self):
self._act_magnitude = 0
super()._reset_sim()
for _ in range(10):
self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32))
self.sim.step()
@property
def box(self):
return self.sim.data.get_site_xpos("box_site")
def is_success(self):
return np.linalg.norm(self.obj - self.box) <= 0.05
def get_reward(self):
dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2])
dist_xyz = np.linalg.norm(self.obj - self.box)
return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy
def _get_obs(self):
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
eef, box = self.eef - self.center_of_table, self.box - self.center_of_table
obj = self.obj - self.center_of_table
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
obs = np.concatenate(
[
eef,
eef_velp,
box,
obj,
obj_rot,
obj_velp,
obj_velr,
eef - box,
eef - obj,
obj - box,
np.array(
[
np.linalg.norm(eef - box),
np.linalg.norm(eef - obj),
np.linalg.norm(obj - box),
gripper_angle,
]
),
],
axis=0,
)
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box}
def _sample_goal(self):
# Gripper
gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3)
super()._set_gripper(gripper_pos, self.gripper_rotation)
# Object
object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3)
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
object_qpos[:3] = object_pos
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
# Box
box_pos = np.array([1.61, 0.18, 0.58])
box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2)
box_qpos = self.sim.data.get_joint_qpos("box_joint0")
box_qpos[:3] = box_pos
self.sim.data.set_joint_qpos("box_joint0", box_qpos)
return self.box
def step(self, action):
self._act_magnitude = np.linalg.norm(action[:3])
return super().step(action)

View File

@@ -0,0 +1,78 @@
import numpy as np
from lerobot.common.envs.simxarm.simxarm import Base
class Push(Base):
def __init__(self):
super().__init__("push")
def _reset_sim(self):
self._act_magnitude = 0
super()._reset_sim()
def is_success(self):
return np.linalg.norm(self.obj - self.goal) <= 0.05
def get_reward(self):
dist = np.linalg.norm(self.obj - self.goal)
penalty = self._act_magnitude**2
return -(dist + 0.15 * penalty)
def _get_obs(self):
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
obj = self.obj - self.center_of_table
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
obs = np.concatenate(
[
eef,
eef_velp,
goal,
obj,
obj_rot,
obj_velp,
obj_velr,
eef - goal,
eef - obj,
obj - goal,
np.array(
[
np.linalg.norm(eef - goal),
np.linalg.norm(eef - obj),
np.linalg.norm(obj - goal),
gripper_angle,
]
),
],
axis=0,
)
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
def _sample_goal(self):
# Gripper
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
super()._set_gripper(gripper_pos, self.gripper_rotation)
# Object
object_pos = self.center_of_table - np.array([0.25, 0, 0.07])
object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1)
object_pos[1] += self.np_random.uniform(-0.08, 0.08, size=1)
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
object_qpos[:3] = object_pos
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
# Goal
self.goal = np.array([1.600, 0.200, 0.545])
self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2)
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
return self.goal
def step(self, action):
self._act_magnitude = np.linalg.norm(action[:3])
return super().step(action)

View File

@@ -0,0 +1,44 @@
import numpy as np
from lerobot.common.envs.simxarm.simxarm import Base
class Reach(Base):
def __init__(self):
super().__init__("reach")
def _reset_sim(self):
self._act_magnitude = 0
super()._reset_sim()
def is_success(self):
return np.linalg.norm(self.eef - self.goal) <= 0.05
def get_reward(self):
dist = np.linalg.norm(self.eef - self.goal)
penalty = self._act_magnitude**2
return -(dist + 0.15 * penalty)
def _get_obs(self):
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
obs = np.concatenate(
[eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0
)
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
def _sample_goal(self):
# Gripper
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
super()._set_gripper(gripper_pos, self.gripper_rotation)
# Goal
self.goal = np.array([1.550, 0.287, 0.580])
self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2)
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
return self.goal
def step(self, action):
self._act_magnitude = np.linalg.norm(action[:3])
return super().step(action)

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from omegaconf import OmegaConf
from termcolor import colored
from lerobot.common.policies.abstract import AbstractPolicy
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def save_model(self, policy, identifier):
def save_model(self, policy: AbstractPolicy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
policy.save_pretrained(fp)
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(

View File

@@ -2,22 +2,45 @@ from collections import deque
import torch
from torch import Tensor, nn
from huggingface_hub import PyTorchModelHubMixin
class AbstractPolicy(nn.Module):
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
# Save policy weights to local directory
>>> policy.save_pretrained("my-awesome-policy")
# Push policy weights to the Hub
>>> policy.push_to_hub("my-awesome-policy")
# Download and initialize policy from the Hub
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
- for classes inheriting from `AbstractDataset`: `available_datasets`
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
- for classes inheriting from `AbstractPolicy`: `name`
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
3. update variables in `tests/test_available.py` by importing your new class
"""
def __init__(self, n_action_steps: int | None):
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None = None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
adds that dimension.
"""
super().__init__()
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
self.n_action_steps = n_action_steps
self.clear_action_queue()
@@ -25,10 +48,10 @@ class AbstractPolicy(nn.Module):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
d = torch.load(fp)
self.load_state_dict(d)

View File

@@ -42,6 +42,8 @@ def kl_divergence(mu, logvar):
class ActionChunkingTransformerPolicy(AbstractPolicy):
name = "act"
def __init__(self, cfg, device, n_action_steps=1):
super().__init__(n_action_steps)
self.cfg = cfg
@@ -134,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
self.load_state_dict(d)
def compute_loss(self, batch):

View File

@@ -1,6 +1,7 @@
"""
Various positional encodings for the transformer.
"""
import math
import torch

View File

@@ -6,6 +6,7 @@ Copy-paste from torch.nn.Transformer with modifications:
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional

View File

@@ -3,6 +3,7 @@ Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import datetime
import os
import pickle

View File

@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
policy.save_pretrained("my-policy")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.

View File

@@ -2,6 +2,7 @@
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np

View File

@@ -13,6 +13,8 @@ from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(AbstractPolicy):
name = "diffusion"
def __init__(
self,
cfg,
@@ -201,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)

View File

@@ -1,35 +1,53 @@
def make_policy(cfg):
""" Factory for policies
"""
from lerobot.common.policies.abstract import AbstractPolicy
def make_policy(cfg: dict) -> AbstractPolicy:
""" Instantiate a policy from the configuration.
Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
Args:
cfg: The configuration (DictConfig)
"""
policy_kwargs = {}
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPC
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPC(cfg.policy, cfg.device)
policy_cls = TDMPCPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
policy_cls = DiffusionPolicy
policy_kwargs = {
"cfg": cfg.policy,
"cfg_device": cfg.device,
"cfg_noise_scheduler": cfg.noise_scheduler,
"cfg_rgb_model": cfg.rgb_model,
"cfg_obs_encoder": cfg.obs_encoder,
"cfg_optimizer": cfg.optimizer,
"cfg_ema": cfg.ema,
"n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
}
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
policy_cls = ActionChunkingTransformerPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
# policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
return policy

View File

@@ -87,9 +87,11 @@ class TOLD(nn.Module):
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
class TDMPC(AbstractPolicy):
class TDMPCPolicy(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
name = "tdmpc"
def __init__(self, cfg, device):
super().__init__(None)
self.action_dim = cfg.action_dim
@@ -120,9 +122,9 @@ class TDMPC(AbstractPolicy):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp, device=None):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
d = torch.load(fp, map_location=device)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])

View File

@@ -1,9 +1,13 @@
import logging
import os.path as osp
import random
from datetime import datetime
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
@@ -26,7 +30,7 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
return device
def set_seed(seed):
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
@@ -63,3 +67,31 @@ def format_big_number(num):
num /= divisor
return num
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# TODO(alexander-soare): Resolve configs without Hydra initialization.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg

View File

@@ -5,7 +5,7 @@ defaults:
hydra:
run:
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
dir: outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
job:
name: default
@@ -26,7 +26,10 @@ fps: ???
offline_prioritized_sampler: true
dataset_id: ???
n_action_steps: ???
n_obs_steps: ???
env: ???
policy: ???

View File

@@ -10,9 +10,11 @@ online_steps: 25000
fps: 50
dataset_id: aloha_sim_insertion_human
env:
name: aloha
task: sim_insertion_human
task: sim_insertion
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]

View File

@@ -10,6 +10,8 @@ online_steps: 25000
fps: 10
dataset_id: pusht
env:
name: pusht
task: pusht

View File

@@ -9,6 +9,8 @@ online_steps: 25000
fps: 15
dataset_id: xarm_lift_medium
env:
name: simxarm
task: lift

View File

@@ -103,29 +103,3 @@ optimizer:
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
training:
device: "cuda:0"
seed: 42
debug: False
resume: True
# optimization
# lr_scheduler: cosine
# lr_warmup_steps: 500
num_epochs: 8000
# gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm.
# use_ema: True
freeze_encoder: False
# training loop control
# in epochs
rollout_every: 50
checkpoint_every: 50
val_every: 1
sample_every: 5
# steps per epoch
max_train_steps: null
max_val_steps: null
# misc
tqdm_interval_sec: 1.0

View File

@@ -1,6 +1,7 @@
# @package _global_
n_action_steps: 1
n_obs_steps: 1
policy:
name: tdmpc

View File

@@ -1,14 +1,47 @@
"""Evaluate a policy on an environment by running rollouts and computing metrics.
The script may be run in one of two ways:
1. By providing the path to a config file with the --config argument.
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
--revision argument.
In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
Examples:
You have a specific config file to go with trained model weights, and want to run 10 episodes.
```
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
eval_episodes=10
```
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
you don't need to specify which weights to use):
```
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
```
"""
import argparse
import json
import logging
import threading
import time
from typing import Tuple, Union
from datetime import datetime as dt
from pathlib import Path
import einops
import hydra
import imageio
import numpy as np
import torch
import tqdm
from huggingface_hub import snapshot_download
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -18,7 +51,7 @@ from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):
@@ -34,13 +67,26 @@ def eval_policy(
video_dir: Path = None,
fps: int = 15,
return_first_video: bool = False,
):
) -> Union[dict, Tuple[dict, torch.Tensor]]:
""" Evaluate a policy on an environment by running rollouts and computing metrics.
Args:
env: The environment to evaluate.
policy: The policy to evaluate.
num_episodes: The number of episodes to evaluate.
max_steps: The maximum number of steps per episode.
save_video: Whether to save videos of the evaluation episodes.
video_dir: The directory to save the videos.
fps: The frames per second for the videos.
return_first_video: Whether to return the first video as a tensor.
"""
if policy is not None:
policy.eval()
start = time.time()
sum_rewards = []
max_rewards = []
successes = []
seeds = []
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
@@ -53,11 +99,16 @@ def eval_policy(
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render()) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
if policy is not None:
policy.clear_action_queue()
if env.is_closed:
env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
seeds.extend(env._next_seed)
with torch.inference_mode():
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
if policy is not None:
policy.clear_action_queue()
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
@@ -65,8 +116,8 @@ def eval_policy(
callback=maybe_render_frame,
break_when_any_done=env.batch_size[0] == 1,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
# be included).
# Figure out where in each rollout sequence the first done condition was encountered (results after
# this won't be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
@@ -107,24 +158,46 @@ def eval_policy(
for thread in threads:
thread.join()
info = {
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
info = { # TODO: change to dataclass
"per_episode": [
{
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"success": success,
"seed": seed,
}
for i, (sum_reward, max_reward, success, seed) in enumerate(
zip(
sum_rewards[:num_episodes],
max_rewards[:num_episodes],
successes[:num_episodes],
seeds[:num_episodes],
strict=True,
)
)
],
"aggregated": {
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
}
if return_first_video:
return info, first_video
return info
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def eval_cli(cfg: dict):
eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def eval(cfg: dict, out_dir=None, stats_path=None):
""" Evaluate a policy.
def eval(cfg: dict, out_dir=None):
Args:
cfg: The configuration (DictConfig).
out_dir: The directory to save the evaluation results (JSON file and videos)
stats_path: The path to the stats file.
"""
if out_dir is None:
raise NotImplementedError()
@@ -135,14 +208,15 @@ def eval(cfg: dict, out_dir=None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
set_global_seed(cfg.seed)
log_output_dir(out_dir)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
logging.info("make_env")
logging.info("Making environment.")
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
@@ -156,7 +230,7 @@ def eval(cfg: dict, out_dir=None):
# when policy is None, rollout a random policy
policy = None
metrics = eval_policy(
info = eval_policy(
env,
policy=policy,
save_video=True,
@@ -165,10 +239,45 @@ def eval(cfg: dict, out_dir=None):
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
)
print(metrics)
print(info["aggregated"])
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
logging.info("End of eval")
if __name__ == "__main__":
eval_cli()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--config", help="Path to a specific yaml config you want to use.")
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [*args.overrides]
# folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
eval(
cfg,
out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)

View File

@@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
from lerobot.scripts.eval import eval_policy
@@ -122,7 +122,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
set_global_seed(cfg.seed)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
@@ -183,7 +183,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
video_dir=Path(out_dir) / "eval",
save_video=True,
)
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")
@@ -224,7 +224,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
policy=td_policy,
auto_cast_to_device=True,
)
assert len(rollout) <= cfg.env.episode_length
assert (
len(rollout.batch_size) == 2
), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
raise NotImplementedError()
num_max_steps = rollout.batch_size[1]
assert num_max_steps <= cfg.env.episode_length
# reshape to have a list of steps to insert into online_buffer
rollout = rollout.reshape(num_parallel_env * num_max_steps)
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)

View File

@@ -25,6 +25,7 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 255].
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
@@ -45,44 +46,63 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
for video_path in video_paths:
logging.info(video_path)
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
out_dir = Path(out_dir)
video_paths = []
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = offline_buffer.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
new_episode = ep_idx != current_ep_idx
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = offline_buffer._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
if new_episode:
logging.info(f"Visualizing episode {current_ep_idx}")
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
for im_key in offline_buffer.image_keys:
if new_episode or no_more_frames:
# append last observed frames (the ones after last action taken)
frames[im_key].append(ep_td[("next", *im_key)])
video_dir = Path(out_dir) / "visualize_dataset"
video_dir.mkdir(parents=True, exist_ok=True)
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
frames[im_key].append(ep_td[im_key])
else:
# When episode has no more frame in its list of observation,
# one frame still remains. It is the result of the last action taken.
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
out_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:
camera = im_key[-1]
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], cfg.fps),
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
@@ -92,19 +112,18 @@ def visualize_dataset(cfg: dict, out_dir=None):
# reset list of frames
del frames[im_key]
# append current cameras images to list of frames
if im_key not in frames:
frames[im_key] = []
frames[im_key].append(ep_td[im_key])
if no_more_frames:
if num_frames_left == 0:
logging.info("Ran out of frames")
break
if current_ep_idx == NUM_EPISODES_TO_RENDER:
break
for thread in threads:
thread.join()
logging.info("End of visualize_dataset")
return video_paths
if __name__ == "__main__":

Binary file not shown.

After

Width:  |  Height:  |  Size: 199 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 160 KiB

287
poetry.lock generated
View File

@@ -327,6 +327,35 @@ files = [
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
]
[[package]]
name = "cmake"
version = "3.29.0.1"
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
optional = false
python-versions = ">=3.7"
files = [
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
]
[package.extras]
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -339,72 +368,72 @@ files = [
]
[[package]]
name = "cython"
version = "3.0.9"
description = "The Cython compiler for writing C extensions in the Python language."
name = "coverage"
version = "7.4.4"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
python-versions = ">=3.8"
files = [
{file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"},
{file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"},
{file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"},
{file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"},
{file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"},
{file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"},
{file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"},
{file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"},
{file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"},
{file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"},
{file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"},
{file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"},
{file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"},
{file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"},
{file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"},
{file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"},
{file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"},
{file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"},
{file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"},
{file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"},
{file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"},
{file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"},
{file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"},
{file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"},
{file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"},
{file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"},
{file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"},
{file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"},
{file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"},
{file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"},
{file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"},
{file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"},
{file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"},
{file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"},
{file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"},
{file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"},
{file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"},
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras]
toml = ["tomli"]
[[package]]
name = "debugpy"
version = "1.8.1"
@@ -639,6 +668,17 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "farama-notifications"
version = "0.0.4"
description = "Notifications for all Farama Foundation maintained libraries."
optional = false
python-versions = "*"
files = [
{file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"},
{file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"},
]
[[package]]
name = "fasteners"
version = "0.19"
@@ -840,43 +880,58 @@ files = [
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]]
name = "gym"
version = "0.26.2"
description = "Gym: A universal API for reinforcement learning environments"
name = "gymnasium"
version = "0.29.1"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
optional = false
python-versions = ">=3.6"
python-versions = ">=3.8"
files = [
{file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"},
{file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"},
{file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"},
]
[package.dependencies]
cloudpickle = ">=1.2.0"
gym_notices = ">=0.0.4"
numpy = ">=1.18.0"
farama-notifications = ">=0.0.1"
numpy = ">=1.21.0"
typing-extensions = ">=4.3.0"
[package.extras]
accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"]
all = ["ale-py (>=0.8.0,<0.9.0)", "box2d-py (==2.3.5)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco_py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (==4.*)"]
atari = ["ale-py (>=0.8.0,<0.9.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.0)", "swig (==4.*)"]
classic-control = ["pygame (==2.1.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (==2.2)"]
mujoco-py = ["mujoco_py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)"]
testing = ["box2d-py (==2.3.5)", "imageio (>=2.14.1)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (==2.2)", "mujoco_py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.0)", "pytest (==7.0.1)", "swig (==4.*)"]
toy-text = ["pygame (==2.1.0)"]
all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"]
atari = ["shimmy[atari] (>=0.1.0,<1.0)"]
box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"]
classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"]
mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"]
mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"]
other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"]
testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"]
toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"]
[[package]]
name = "gym-notices"
version = "0.0.8"
description = "Notices for gym"
name = "gymnasium-robotics"
version = "1.2.4"
description = "Robotics environments for the Gymnasium repo."
optional = false
python-versions = "*"
python-versions = ">=3.8"
files = [
{file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"},
{file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"},
{file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"},
{file = "gymnasium_robotics-1.2.4-py3-none-any.whl", hash = "sha256:c2cb23e087ca0280ae6802837eb7b3a6d14e5bd24c00803ab09f015fcff3eef5"},
]
[package.dependencies]
gymnasium = ">=0.26"
imageio = "*"
Jinja2 = ">=3.0.3"
mujoco = ">=2.3.3,<3.0"
numpy = ">=1.21.0"
PettingZoo = ">=1.23.0"
[package.extras]
mujoco-py = ["cython (<3)", "mujoco-py (>=2.1,<2.2)"]
testing = ["Jinja2 (>=3.0.3)", "PettingZoo (>=1.23.0)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "pytest (==7.0.1)"]
[[package]]
name = "h5py"
version = "3.10.0"
@@ -1506,25 +1561,6 @@ glfw = "*"
numpy = "*"
pyopengl = "*"
[[package]]
name = "mujoco-py"
version = "2.1.2.14"
description = ""
optional = false
python-versions = ">=3.6"
files = [
{file = "mujoco-py-2.1.2.14.tar.gz", hash = "sha256:eb5b14485acf80a3cf8c15f4b080c6a28a9f79e68869aa696d16cbd51ea7706f"},
{file = "mujoco_py-2.1.2.14-py3-none-any.whl", hash = "sha256:37c0b41bc0153a8a0eb3663103a67c60f65467753f74e4ff6e68b879f3e3a71f"},
]
[package.dependencies]
cffi = ">=1.10"
Cython = ">=0.27.2"
fasteners = ">=0.15,<1.0"
glfw = ">=1.4.0"
imageio = ">=2.1.2"
numpy = ">=1.11"
[[package]]
name = "networkx"
version = "3.2.1"
@@ -1940,6 +1976,31 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]]
name = "pettingzoo"
version = "1.24.3"
description = "Gymnasium for multi-agent reinforcement learning."
optional = false
python-versions = ">=3.8"
files = [
{file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"},
{file = "pettingzoo-1.24.3.tar.gz", hash = "sha256:91f9094f18e06fb74b98f4099cd22e8ae4396125e51719d50b30c9f1c7ab07e6"},
]
[package.dependencies]
gymnasium = ">=0.28.0"
numpy = ">=1.21.0"
[package.extras]
all = ["box2d-py (==2.3.5)", "chess (==1.9.4)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "rlcard (==1.0.5)", "scipy (>=1.4.1)", "shimmy[openspiel] (>=1.2.0)"]
atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.3.0)"]
butterfly = ["pygame (==2.3.0)", "pymunk (==6.2.0)"]
classic = ["chess (==1.9.4)", "pygame (==2.3.0)", "rlcard (==1.0.5)", "shimmy[openspiel] (>=1.2.0)"]
mpe = ["pygame (==2.3.0)"]
other = ["pillow (>=8.0.1)"]
sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"]
testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"]
[[package]]
name = "pillow"
version = "10.2.0"
@@ -2342,6 +2403,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-cov"
version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[package.dependencies]
coverage = {version = ">=5.2.1", extras = ["toml"]}
pytest = ">=4.6"
[package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -3510,4 +3589,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "1a45c808e1c48bcbf4319d4cf6876771b7d50f40a5a8968a8b7f3af36192bf34"
content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"

View File

@@ -21,7 +21,6 @@ packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = "^3.10"
cython = "^3.0.8"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
dm-env = "^1.6"
@@ -42,9 +41,7 @@ mpmath = "^1.3.0"
torch = "^2.2.1"
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "2.3.7"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
mujoco = "^2.3.7"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = "^0.17.1"
@@ -52,12 +49,16 @@ h5py = "^3.10.0"
dm-control = "1.0.14"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
robomimic = "0.2.0"
gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
pytest = "^8.1.0"
pytest-cov = "^5.0.0"
[tool.ruff]
@@ -90,7 +91,7 @@ exclude = [
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
ignore-init-module-imports = true
[tool.poetry-dynamic-versioning]
enable = true

Some files were not shown because too many files have changed in this diff Show More