Compare commits
39 Commits
user/rcade
...
thom-propo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1e47202c0 | ||
|
|
24821fee24 | ||
|
|
4751642ace | ||
|
|
11cbf1bea1 | ||
|
|
f1148b8c2d | ||
|
|
2a98cc71ed | ||
|
|
a7c9b78e56 | ||
|
|
404b8f8a75 | ||
|
|
17c2bbbeb8 | ||
|
|
006e5feabf | ||
|
|
b99ee8180a | ||
|
|
6bddcb647e | ||
|
|
58df2066a9 | ||
|
|
c89aa4f8ed | ||
|
|
62aad7104b | ||
|
|
9d9148dad8 | ||
|
|
1b6cb2b1be | ||
|
|
6f1a0aefab | ||
|
|
b7c9c33072 | ||
|
|
120f0aef5c | ||
|
|
032200e32c | ||
|
|
de1e9187c8 | ||
|
|
4f8f1926f9 | ||
|
|
6710121a29 | ||
|
|
5f4b8ab899 | ||
|
|
18e7f4c3e6 | ||
|
|
643d64e2a8 | ||
|
|
c037722e23 | ||
|
|
6cd671040f | ||
|
|
b6353964ba | ||
|
|
64c8851c40 | ||
|
|
dc745e3037 | ||
|
|
6f0c2445ca | ||
|
|
d1d2229407 | ||
|
|
68d02c80cf | ||
|
|
011f2d27fe | ||
|
|
be4441c7ff | ||
|
|
1ed0110900 | ||
|
|
cb6d1e0871 |
116
.github/poetry/cpu/poetry.lock
generated
vendored
116
.github/poetry/cpu/poetry.lock
generated
vendored
@@ -327,6 +327,35 @@ files = [
|
||||
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "3.29.0.1"
|
||||
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
|
||||
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@@ -338,6 +367,73 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.4.4"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
|
||||
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
|
||||
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
version = "1.8.1"
|
||||
@@ -2103,6 +2199,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "5.0.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
|
||||
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=5.2.1", extras = ["toml"]}
|
||||
pytest = ">=4.6"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -3216,4 +3330,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "93c406139c456780b3d309d7ed3d68ea60cc0e8893c1ee717692984e573d3404"
|
||||
content-hash = "8800bb8b24312d17b765cd2ce2799f49436171dd5fbf1bec3b07f853cfa9befd"
|
||||
|
||||
2
.github/poetry/cpu/pyproject.toml
vendored
2
.github/poetry/cpu/pyproject.toml
vendored
@@ -52,12 +52,14 @@ robomimic = "0.2.0"
|
||||
huggingface-hub = "^0.21.4"
|
||||
gymnasium-robotics = "^1.2.4"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
pytest-cov = "^5.0.0"
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
|
||||
117
.github/workflows/test.yml
vendored
117
.github/workflows/test.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Test
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
@@ -10,7 +10,7 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
tests:
|
||||
if: |
|
||||
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
|
||||
${{ github.event_name == 'push' }}
|
||||
@@ -19,7 +19,6 @@ jobs:
|
||||
POETRY_VERSION: 1.8.2
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
steps:
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
@@ -110,34 +109,126 @@ jobs:
|
||||
run: poetry install --no-interaction
|
||||
|
||||
#----------------------------------------------
|
||||
# run tests
|
||||
# run tests & coverage
|
||||
#----------------------------------------------
|
||||
- name: Run tests
|
||||
env:
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest tests
|
||||
pytest --cov=./lerobot --cov-report=xml tests
|
||||
|
||||
- name: Test train pusht end-to-end
|
||||
# TODO(aliberts): Link with HF Codecov account
|
||||
# - name: Upload coverage reports to Codecov with GitHub Action
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# files: ./coverage.xml
|
||||
# verbose: true
|
||||
|
||||
#----------------------------------------------
|
||||
# run end-to-end tests
|
||||
#----------------------------------------------
|
||||
- name: Test train ACT on ALOHA end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.job.name=pusht \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
horizon=20 \
|
||||
policy.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
- name: Test eval ACT on ALOHA end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/act/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||
|
||||
# TODO(aliberts): This takes ~2mn to run, needs to be improved
|
||||
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
|
||||
# run: |
|
||||
# source .venv/bin/activate
|
||||
# python lerobot/scripts/eval.py \
|
||||
# --config lerobot/configs/default.yaml \
|
||||
# policy=act \
|
||||
# env=aloha \
|
||||
# eval_episodes=1 \
|
||||
# device=cpu
|
||||
|
||||
- name: Test train Diffusion on PushT end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=1 \
|
||||
hydra.run.dir=tests/outputs/
|
||||
save_freq=2 \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
- name: Test eval pusht end-to-end
|
||||
- name: Test eval Diffusion on PushT end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/.hydra/config.yaml \
|
||||
wandb.enable=False \
|
||||
--config tests/outputs/diffusion/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/models/1.pt
|
||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||
|
||||
- name: Test eval Diffusion on PushT end-to-end (policy is None)
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
eval_episodes=1 \
|
||||
device=cpu
|
||||
|
||||
- name: Test train TDMPC on Simxarm end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=simxarm \
|
||||
wandb.enable=False \
|
||||
offline_steps=1 \
|
||||
online_steps=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
- name: Test eval TDMPC on Simxarm end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||
|
||||
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
policy=tdmpc \
|
||||
env=simxarm \
|
||||
eval_episodes=1 \
|
||||
device=cpu
|
||||
|
||||
96
README.md
96
README.md
@@ -8,9 +8,25 @@
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
# LeRobot
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
|
||||
[](https://codecov.io/gh/huggingface/lerobot)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
</div>
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
**State-of-the-art machine learning for real-world robotics**
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
@@ -44,26 +60,21 @@
|
||||
|
||||
## Installation
|
||||
|
||||
Create a virtual environment with Python 3.10, e.g. using `conda`:
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
|
||||
Then, install 🤗 LeRobot:
|
||||
```bash
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
|
||||
```bash
|
||||
mkdir ~/tmp
|
||||
export TMPDIR='~/tmp'
|
||||
python -m pip install .
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||
@@ -135,22 +146,17 @@ hydra.run.dir=outputs/visualize_dataset/example
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
You can import our environment class, download pretrained policies from the HuggingFace hub, and use our rollout utilities with rendering:
|
||||
```python
|
||||
""" Copy pasted from `examples/2_evaluate_pretrained_policy.py`
|
||||
# TODO
|
||||
```
|
||||
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
--revision v1.0 \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
After launching training of your own policy, you can also re-evaluate the checkpoints with:
|
||||
After training your own policy, you can also re-evaluate the checkpoints with:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
@@ -163,19 +169,9 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub):
|
||||
```python
|
||||
""" Copy pasted from `examples/3_train_policy.py`
|
||||
# TODO
|
||||
```
|
||||
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=outputs/train/example
|
||||
```
|
||||
|
||||
You can easily train any policy on any environment:
|
||||
In general, you can use our training script to easily train any policy on any environment:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
env=aloha \
|
||||
@@ -189,11 +185,11 @@ hydra.run.dir=outputs/train/aloha_act
|
||||
|
||||
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||
|
||||
**TODO**
|
||||
### TODO
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
**Follow our style**
|
||||
### Follow our style
|
||||
|
||||
```bash
|
||||
# install if needed
|
||||
@@ -202,25 +198,33 @@ pre-commit install
|
||||
pre-commit
|
||||
```
|
||||
|
||||
**Add dependencies**
|
||||
### Add dependencies
|
||||
|
||||
Instead of `pip install some-package`, we use `poetry` to track the versions of our dependencies:
|
||||
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
|
||||
Install the project with:
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
Then, the equivalent of `pip install some-package`, would just be:
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
||||
```bash
|
||||
# add the new package to your main poetry env
|
||||
# Add the new package to your main poetry env
|
||||
poetry add some-package
|
||||
# add the same package to the CPU-only env dedicated to CI
|
||||
# Add the same package to the CPU-only env dedicated to CI
|
||||
conda create -y -n lerobot-ci python=3.10
|
||||
conda activate lerobot-ci
|
||||
cd .github/poetry/cpu
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
**Run tests locally**
|
||||
### Run tests locally
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
@@ -251,7 +255,7 @@ Run tests
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
**Add a new dataset**
|
||||
### Add a new dataset
|
||||
|
||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```bash
|
||||
@@ -312,7 +316,7 @@ Finally, you might want to mock the dataset if you need to update the unit tests
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
**Add a pretrained policy**
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the HuggingFace hub.
|
||||
|
||||
@@ -348,7 +352,7 @@ huggingface-cli upload $HUB_ID to_upload
|
||||
See `eval.py` for an example of how a user may use your policy.
|
||||
|
||||
|
||||
**Improve your code with profiling**
|
||||
### Improve your code with profiling
|
||||
|
||||
An example of a code snippet to profile the evaluation of a policy:
|
||||
```python
|
||||
|
||||
@@ -1 +1,39 @@
|
||||
# TODO
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
# Get a pretrained policy from the hub.
|
||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||
folder = Path(snapshot_download(hub_id))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# folder = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
stats_path = folder / "stats.pth" # normalization stats
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
# Create a Hydra config.
|
||||
cfg = init_hydra_config(config_path, overrides)
|
||||
|
||||
# Evaluate the policy and save the outputs including metrics and videos.
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
||||
@@ -1 +1,55 @@
|
||||
# TODO
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
overrides = [
|
||||
"env=pusht",
|
||||
"policy=diffusion",
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
"offline_steps=5000",
|
||||
"log_freq=250",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
policy.train()
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
for offline_step in trange(cfg.offline_steps):
|
||||
train_info = policy.update(offline_buffer, offline_step)
|
||||
if offline_step % cfg.log_freq == 0:
|
||||
print(train_info)
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
policy.save_pretrained(output_directory / "model.pt")
|
||||
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||
|
||||
@@ -81,13 +81,8 @@ def make_offline_buffer(
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
||||
dataset_id = cfg.get("dataset_id")
|
||||
if dataset_id is None and cfg.env.name == "pusht":
|
||||
dataset_id = "pusht"
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
dataset_id=cfg.dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
|
||||
@@ -206,7 +206,7 @@ class AlohaEnv(AbstractEnv):
|
||||
if self.from_pixels:
|
||||
if isinstance(self.image_size, int):
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
elif OmegaConf.is_list(self.image_size):
|
||||
elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
|
||||
assert len(self.image_size) == 3 # c h w
|
||||
assert self.image_size[0] == 3 # c is RGB
|
||||
image_shape = tuple(self.image_size)
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
@@ -67,11 +68,11 @@ class Logger:
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
def save_model(self, policy: AbstractPolicy, identifier):
|
||||
if self._save_model:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
policy.save_pretrained(fp)
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
|
||||
@@ -2,14 +2,25 @@ from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module):
|
||||
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Base policy which all policies should be derived from.
|
||||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
documentation for more information.
|
||||
|
||||
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
|
||||
# Save policy weights to local directory
|
||||
>>> policy.save_pretrained("my-awesome-policy")
|
||||
|
||||
# Push policy weights to the Hub
|
||||
>>> policy.push_to_hub("my-awesome-policy")
|
||||
|
||||
# Download and initialize policy from the Hub
|
||||
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
def __init__(self, n_action_steps: int | None = None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def save(self, fp):
|
||||
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
def load(self, fp, device=None):
|
||||
d = torch.load(fp, map_location=device)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
|
||||
@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
|
||||
Then in that same runtime you can also save the weights with the new aligned state_dict:
|
||||
|
||||
```
|
||||
policy.save("weights.pt")
|
||||
policy.save_pretrained("my-policy")
|
||||
```
|
||||
|
||||
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
|
||||
|
||||
@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
def load(self, fp, device=None):
|
||||
d = torch.load(fp, map_location=device)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||
if len(missing_keys) > 0:
|
||||
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||
|
||||
@@ -1,35 +1,53 @@
|
||||
def make_policy(cfg):
|
||||
""" Factory for policies
|
||||
"""
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
|
||||
def make_policy(cfg: dict) -> AbstractPolicy:
|
||||
""" Instantiate a policy from the configuration.
|
||||
Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
|
||||
|
||||
Args:
|
||||
cfg: The configuration (DictConfig)
|
||||
|
||||
"""
|
||||
policy_kwargs = {}
|
||||
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPCPolicy(cfg.policy, cfg.device)
|
||||
policy_cls = TDMPCPolicy
|
||||
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
policy_cls = DiffusionPolicy
|
||||
policy_kwargs = {
|
||||
"cfg": cfg.policy,
|
||||
"cfg_device": cfg.device,
|
||||
"cfg_noise_scheduler": cfg.noise_scheduler,
|
||||
"cfg_rgb_model": cfg.rgb_model,
|
||||
"cfg_obs_encoder": cfg.obs_encoder,
|
||||
"cfg_optimizer": cfg.optimizer,
|
||||
"cfg_ema": cfg.ema,
|
||||
"n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
}
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
)
|
||||
policy_cls = ActionChunkingTransformerPolicy
|
||||
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
# policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
|
||||
policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
|
||||
|
||||
# TODO(rcadene): hack for old pretrained models from fowm
|
||||
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
||||
if "offline" in cfg.pretrained_model_path:
|
||||
@@ -38,6 +56,5 @@ def make_policy(cfg):
|
||||
policy.step[0] = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
def load(self, fp, device=None):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
d = torch.load(fp)
|
||||
d = torch.load(fp, map_location=device)
|
||||
self.model.load_state_dict(d["model"])
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
@@ -63,3 +67,31 @@ def format_big_number(num):
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
"""
|
||||
# TODO(alexander-soare): Resolve configs without Hydra initialization.
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
return cfg
|
||||
|
||||
@@ -103,29 +103,3 @@ optimizer:
|
||||
betas: [0.95, 0.999]
|
||||
eps: 1.0e-8
|
||||
weight_decay: 1.0e-6
|
||||
|
||||
training:
|
||||
device: "cuda:0"
|
||||
seed: 42
|
||||
debug: False
|
||||
resume: True
|
||||
# optimization
|
||||
# lr_scheduler: cosine
|
||||
# lr_warmup_steps: 500
|
||||
num_epochs: 8000
|
||||
# gradient_accumulate_every: 1
|
||||
# EMA destroys performance when used with BatchNorm
|
||||
# replace BatchNorm with GroupNorm.
|
||||
# use_ema: True
|
||||
freeze_encoder: False
|
||||
# training loop control
|
||||
# in epochs
|
||||
rollout_every: 50
|
||||
checkpoint_every: 50
|
||||
val_every: 1
|
||||
sample_every: 5
|
||||
# steps per epoch
|
||||
max_train_steps: null
|
||||
max_val_steps: null
|
||||
# misc
|
||||
tqdm_interval_sec: 1.0
|
||||
|
||||
@@ -30,14 +30,13 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os.path as osp
|
||||
import threading
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -52,7 +51,7 @@ from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
@@ -68,7 +67,19 @@ def eval_policy(
|
||||
video_dir: Path = None,
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
):
|
||||
) -> Union[dict, Tuple[dict, torch.Tensor]]:
|
||||
""" Evaluate a policy on an environment by running rollouts and computing metrics.
|
||||
|
||||
Args:
|
||||
env: The environment to evaluate.
|
||||
policy: The policy to evaluate.
|
||||
num_episodes: The number of episodes to evaluate.
|
||||
max_steps: The maximum number of steps per episode.
|
||||
save_video: Whether to save videos of the evaluation episodes.
|
||||
video_dir: The directory to save the videos.
|
||||
fps: The frames per second for the videos.
|
||||
return_first_video: Whether to return the first video as a tensor.
|
||||
"""
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
start = time.time()
|
||||
@@ -147,7 +158,7 @@ def eval_policy(
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
info = {
|
||||
info = { # TODO: change to dataclass
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
@@ -180,6 +191,13 @@ def eval_policy(
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
""" Evaluate a policy.
|
||||
|
||||
Args:
|
||||
cfg: The configuration (DictConfig).
|
||||
out_dir: The directory to save the evaluation results (JSON file and videos)
|
||||
stats_path: The path to the stats file.
|
||||
"""
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -195,6 +213,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
||||
|
||||
logging.info("Making environment.")
|
||||
@@ -229,19 +248,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
@@ -259,19 +265,15 @@ if __name__ == "__main__":
|
||||
|
||||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
hydra.initialize(
|
||||
config_path=str(
|
||||
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
|
||||
)
|
||||
)
|
||||
cfg = hydra.compose(Path(args.config).stem, args.overrides)
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
|
||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||
cfg = hydra.compose("config", args.overrides)
|
||||
cfg.policy.pretrained_model_path = folder / "model.pt"
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [*args.overrides]
|
||||
# folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
|
||||
116
poetry.lock
generated
116
poetry.lock
generated
@@ -327,6 +327,35 @@ files = [
|
||||
{file = "cloudpickle-3.0.0.tar.gz", hash = "sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "3.29.0.1"
|
||||
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "cmake-3.29.0.1-py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:ec8b39fdeb75c48fd5a2894658a1ca75f94fb49b421c1f753d86d3e5d5e9f196"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ead7dc5176a6c6347b3fc19532c25ec328f9279b6213902ac930242334e7b621"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f7c7dabf3dd40cb830d2eded43d51c5a3737625bbc5ab6916041c04e352b74f8"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de8f55198f4a820daf2c57645a4bb8cd1064dc92d950ad95be14c5ffabc15bd4"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bb8aaa3419eafd466931e4dcc161d3e5e6a82730ab508c75946ff4fc883b3f6"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e956e4b6c2d8d6bbee399bf3c77e5b901d916fe8f35d6b2f58444d5892c4602b"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a8f7c8b07e6ab0dd444c5b74e658d5013ca0da456041029f734d751090bb7ec"},
|
||||
{file = "cmake-3.29.0.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22493049b6383ea2baa7237a326c2914ab4a7b3e1642f4233245e3a34aae39f6"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:d09573411901fc8a3ede7433713c000ac7b81cda0d771874e9182770acf29eb4"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_i686.whl", hash = "sha256:c85e35ec572e54152154637f24d6bde316fbcf94dfea644bd8f22b1855a09abe"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:a36f3b19a5caa6c63aa70bfa0b262bae8d296e68c6e6d8e918edc5c51d952bf8"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:85a0e28eaaee311d50fcee60f730e5a44a65b3cefc556a1163bbabd7328acd60"},
|
||||
{file = "cmake-3.29.0.1-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:58879ef15dd8344e1583a36cead794fb0fee13f78c590a56749283ac2e27d30a"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win32.whl", hash = "sha256:068a3e7461dd9e487f5f3f720ae8072c2eeb37239ebe4642c4ae29058d83347f"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win_amd64.whl", hash = "sha256:c097892b3653e2d2d41d055c80a9af029fdd24f9eff2ecb66576e4da8b85b1c7"},
|
||||
{file = "cmake-3.29.0.1-py3-none-win_arm64.whl", hash = "sha256:1808071047cb49ed0fe2359e4c310b49880c2805cad4ea9f03c959f51b881ac7"},
|
||||
{file = "cmake-3.29.0.1.tar.gz", hash = "sha256:ec49a7a4480959c229d9d2aa77f7885859c17a45fc66981aaf4551ceffb4d030"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["coverage (>=4.2)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@@ -338,6 +367,73 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.4.4"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "coverage-7.4.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0be5efd5127542ef31f165de269f77560d6cdef525fffa446de6f7e9186cfb2"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ccd341521be3d1b3daeb41960ae94a5e87abe2f46f17224ba5d6f2b8398016cf"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fa497a8ab37784fbb20ab699c246053ac294d13fc7eb40ec007a5043ec91f8"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1a93009cb80730c9bca5d6d4665494b725b6e8e157c1cb7f2db5b4b122ea562"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:690db6517f09336559dc0b5f55342df62370a48f5469fabf502db2c6d1cffcd2"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:09c3255458533cb76ef55da8cc49ffab9e33f083739c8bd4f58e79fecfe288f7"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ce1415194b4a6bd0cdcc3a1dfbf58b63f910dcb7330fe15bdff542c56949f87"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b91cbc4b195444e7e258ba27ac33769c41b94967919f10037e6355e998af255c"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-win32.whl", hash = "sha256:598825b51b81c808cb6f078dcb972f96af96b078faa47af7dfcdf282835baa8d"},
|
||||
{file = "coverage-7.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:09ef9199ed6653989ebbcaacc9b62b514bb63ea2f90256e71fea3ed74bd8ff6f"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0f9f50e7ef2a71e2fae92774c99170eb8304e3fdf9c8c3c7ae9bab3e7229c5cf"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:623512f8ba53c422fcfb2ce68362c97945095b864cda94a92edbaf5994201083"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0513b9508b93da4e1716744ef6ebc507aff016ba115ffe8ecff744d1322a7b63"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40209e141059b9370a2657c9b15607815359ab3ef9918f0196b6fccce8d3230f"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a2b2b78c78293782fd3767d53e6474582f62443d0504b1554370bde86cc8227"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:73bfb9c09951125d06ee473bed216e2c3742f530fc5acc1383883125de76d9cd"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f384c3cc76aeedce208643697fb3e8437604b512255de6d18dae3f27655a384"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:54eb8d1bf7cacfbf2a3186019bcf01d11c666bd495ed18717162f7eb1e9dd00b"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-win32.whl", hash = "sha256:cac99918c7bba15302a2d81f0312c08054a3359eaa1929c7e4b26ebe41e9b286"},
|
||||
{file = "coverage-7.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:b14706df8b2de49869ae03a5ccbc211f4041750cd4a66f698df89d44f4bd30ec"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:201bef2eea65e0e9c56343115ba3814e896afe6d36ffd37bab783261db430f76"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:41c9c5f3de16b903b610d09650e5e27adbfa7f500302718c9ffd1c12cf9d6818"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d898fe162d26929b5960e4e138651f7427048e72c853607f2b200909794ed978"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ea79bb50e805cd6ac058dfa3b5c8f6c040cb87fe83de10845857f5535d1db70"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce4b94265ca988c3f8e479e741693d143026632672e3ff924f25fab50518dd51"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00838a35b882694afda09f85e469c96367daa3f3f2b097d846a7216993d37f4c"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fdfafb32984684eb03c2d83e1e51f64f0906b11e64482df3c5db936ce3839d48"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:69eb372f7e2ece89f14751fbcbe470295d73ed41ecd37ca36ed2eb47512a6ab9"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-win32.whl", hash = "sha256:137eb07173141545e07403cca94ab625cc1cc6bc4c1e97b6e3846270e7e1fea0"},
|
||||
{file = "coverage-7.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:d71eec7d83298f1af3326ce0ff1d0ea83c7cb98f72b577097f9083b20bdaf05e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ae728ff3b5401cc320d792866987e7e7e880e6ebd24433b70a33b643bb0384"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cc4f1358cb0c78edef3ed237ef2c86056206bb8d9140e73b6b89fbcfcbdd40e1"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8130a2aa2acb8788e0b56938786c33c7c98562697bf9f4c7d6e8e5e3a0501e4a"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cf271892d13e43bc2b51e6908ec9a6a5094a4df1d8af0bfc360088ee6c684409"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4cdc86d54b5da0df6d3d3a2f0b710949286094c3a6700c21e9015932b81447e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ae71e7ddb7a413dd60052e90528f2f65270aad4b509563af6d03d53e979feafd"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:38dd60d7bf242c4ed5b38e094baf6401faa114fc09e9e6632374388a404f98e7"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa5b1c1bfc28384f1f53b69a023d789f72b2e0ab1b3787aae16992a7ca21056c"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-win32.whl", hash = "sha256:dfa8fe35a0bb90382837b238fff375de15f0dcdb9ae68ff85f7a63649c98527e"},
|
||||
{file = "coverage-7.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:b2991665420a803495e0b90a79233c1433d6ed77ef282e8e152a324bbbc5e0c8"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3b799445b9f7ee8bf299cfaed6f5b226c0037b74886a4e11515e569b36fe310d"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b4d33f418f46362995f1e9d4f3a35a1b6322cb959c31d88ae56b0298e1c22357"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aadacf9a2f407a4688d700e4ebab33a7e2e408f2ca04dbf4aef17585389eff3e"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c95949560050d04d46b919301826525597f07b33beba6187d04fa64d47ac82e"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff7687ca3d7028d8a5f0ebae95a6e4827c5616b31a4ee1192bdfde697db110d4"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5fc1de20b2d4a061b3df27ab9b7c7111e9a710f10dc2b84d33a4ab25065994ec"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:c74880fc64d4958159fbd537a091d2a585448a8f8508bf248d72112723974cbd"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:742a76a12aa45b44d236815d282b03cfb1de3b4323f3e4ec933acfae08e54ade"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-win32.whl", hash = "sha256:d89d7b2974cae412400e88f35d86af72208e1ede1a541954af5d944a8ba46c57"},
|
||||
{file = "coverage-7.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ca28a302acb19b6af89e90f33ee3e1906961f94b54ea37de6737b7ca9d8827c"},
|
||||
{file = "coverage-7.4.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b2c5edc4ac10a7ef6605a966c58929ec6c1bd0917fb8c15cb3363f65aa40e677"},
|
||||
{file = "coverage-7.4.4.tar.gz", hash = "sha256:c901df83d097649e257e803be22592aedfd5182f07b3cc87d640bbb9afd50f49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "debugpy"
|
||||
version = "1.8.1"
|
||||
@@ -2307,6 +2403,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "5.0.0"
|
||||
description = "Pytest plugin for measuring coverage."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
|
||||
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coverage = {version = ">=5.2.1", extras = ["toml"]}
|
||||
pytest = ">=4.6"
|
||||
|
||||
[package.extras]
|
||||
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -3475,4 +3589,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "99addbfc02bcd35a308f4ecc5b4285c9c5054118f4aadea27650d8bf355d9616"
|
||||
content-hash = "174c7d42f8039eedd2c447a4e6cae5169782cbd94346b5606572a0010194ca05"
|
||||
|
||||
@@ -51,12 +51,14 @@ huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
|
||||
robomimic = "0.2.0"
|
||||
gymnasium-robotics = "^1.2.4"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
pytest-cov = "^5.0.0"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
@@ -2,8 +2,9 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
from .utils import DEVICE, init_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -18,7 +19,10 @@ from .utils import DEVICE, init_config
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, dataset_id):
|
||||
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
||||
)
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
for key in offline_buffer.image_keys:
|
||||
img = offline_buffer[0].get(key)
|
||||
|
||||
@@ -4,11 +4,13 @@ import torch
|
||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
from .utils import DEVICE, init_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
def print_spec_rollout(env):
|
||||
@@ -38,6 +40,26 @@ def print_spec_rollout(env):
|
||||
print("data from rollout:", simple_rollout(100))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task,from_pixels,pixels_only",
|
||||
[
|
||||
("sim_insertion", True, False),
|
||||
("sim_insertion", True, True),
|
||||
("sim_transfer_cube", True, False),
|
||||
("sim_transfer_cube", True, True),
|
||||
],
|
||||
)
|
||||
def test_aloha(task, from_pixels, pixels_only):
|
||||
env = AlohaEnv(
|
||||
task,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=[3, 480, 640] if from_pixels else None,
|
||||
)
|
||||
# print_spec_rollout(env)
|
||||
check_env_specs(env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task,from_pixels,pixels_only",
|
||||
[
|
||||
@@ -89,7 +111,10 @@ def test_pusht(from_pixels, pixels_only):
|
||||
],
|
||||
)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
|
||||
@@ -1,19 +1,70 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"examples/1_visualize_dataset.py",
|
||||
"examples/2_evaluate_pretrained_policy.py",
|
||||
"examples/3_train_policy.py",
|
||||
],
|
||||
)
|
||||
def test_example(path):
|
||||
|
||||
with open(path, 'r') as file:
|
||||
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
|
||||
for f, r in zip(finds, replaces):
|
||||
assert f in text
|
||||
text = text.replace(f, r)
|
||||
return text
|
||||
|
||||
|
||||
def test_example_1():
|
||||
path = "examples/1_visualize_dataset.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
exec(file_contents)
|
||||
|
||||
if path == "examples/1_visualize_dataset.py":
|
||||
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
|
||||
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
|
||||
|
||||
|
||||
def test_examples_3_and_2():
|
||||
"""
|
||||
Train a model with example 3, check the outputs.
|
||||
Evaluate the trained model with example 2, check the outputs.
|
||||
"""
|
||||
|
||||
path = "examples/3_train_policy.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
|
||||
# Do less steps and use CPU.
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
['"offline_steps=5000"', '"device=cuda"'],
|
||||
['"offline_steps=1"', '"device=cpu"'],
|
||||
)
|
||||
|
||||
exec(file_contents)
|
||||
|
||||
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
|
||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||
|
||||
path = "examples/2_evaluate_pretrained_policy.py"
|
||||
|
||||
with open(path, "r") as file:
|
||||
file_contents = file.read()
|
||||
|
||||
# Do less evals, use CPU, and use the local model.
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
[
|
||||
'"eval_episodes=10"',
|
||||
'"rollout_batch_size=10"',
|
||||
'"device=cuda"',
|
||||
'# folder = Path("outputs/train/example_pusht_diffusion")',
|
||||
'hub_id = "lerobot/diffusion_policy_pusht_image"',
|
||||
"folder = Path(snapshot_download(hub_id)",
|
||||
],
|
||||
[
|
||||
'"eval_episodes=1"',
|
||||
'"rollout_batch_size=1"',
|
||||
'"device=cpu"',
|
||||
'folder = Path("outputs/train/example_pusht_diffusion")',
|
||||
"",
|
||||
"",
|
||||
],
|
||||
)
|
||||
|
||||
assert Path(f"outputs/train/example_pusht_diffusion").exists()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from omegaconf import open_dict
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictModule
|
||||
@@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
from .utils import DEVICE, init_config
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
@@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
"""
|
||||
cfg = init_config(
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
import os
|
||||
import hydra
|
||||
from hydra import compose, initialize
|
||||
|
||||
CONFIG_PATH = "../lerobot/configs"
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
|
||||
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
|
||||
|
||||
def init_config(config_name="default", overrides=None):
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=CONFIG_PATH)
|
||||
cfg = compose(config_name=config_name, overrides=overrides)
|
||||
return cfg
|
||||
|
||||
Reference in New Issue
Block a user