forked from tangger/lerobot
Compare commits
89 Commits
user/rcade
...
fix_path
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d374873849 | ||
|
|
9c88071bc7 | ||
|
|
5805a7ffb1 | ||
|
|
41521f7e96 | ||
|
|
b10c9507d4 | ||
|
|
a311d38796 | ||
|
|
19730b3412 | ||
|
|
95e84079ef | ||
|
|
8e856f1bf7 | ||
|
|
8c2b47752a | ||
|
|
f515cb6efd | ||
|
|
c3f8d14fd8 | ||
|
|
8c56770318 | ||
|
|
998dd2b874 | ||
|
|
7331df81d2 | ||
|
|
2c5d49cad5 | ||
|
|
5881eec376 | ||
|
|
29c73844b1 | ||
|
|
f9258898ff | ||
|
|
9d002032d1 | ||
|
|
060bac7672 | ||
|
|
337208f28d | ||
|
|
48e70e044e | ||
|
|
4449c06823 | ||
|
|
a94800fc8a | ||
|
|
a207b416b7 | ||
|
|
78690d197f | ||
|
|
6d6c84b4a3 | ||
|
|
772a826bf2 | ||
|
|
2cb8ae5037 | ||
|
|
fab2b3240b | ||
|
|
84a1647c01 | ||
|
|
ccd5dc5a42 | ||
|
|
c1e9c13ade | ||
|
|
00fe4f4f18 | ||
|
|
225eebde40 | ||
|
|
816b2e9d63 | ||
|
|
a7ef4a6a33 | ||
|
|
d4ea4f0ad1 | ||
|
|
f54ee7cda0 | ||
|
|
134009f337 | ||
|
|
7982425670 | ||
|
|
6c867d78ef | ||
|
|
302b78962c | ||
|
|
59397fb44a | ||
|
|
1cc621ec36 | ||
|
|
471ebfef62 | ||
|
|
30753d879c | ||
|
|
c6fb40fb29 | ||
|
|
fa7a947acc | ||
|
|
450e32e4b5 | ||
|
|
0da85b2cef | ||
|
|
f2c7ab5b3b | ||
|
|
cde866dac0 | ||
|
|
a54a0feb63 | ||
|
|
f440a681ad | ||
|
|
35bd577deb | ||
|
|
327f60e4be | ||
|
|
74ad9d5154 | ||
|
|
89eaab140b | ||
|
|
7dbdbb051c | ||
|
|
4cc7e1539e | ||
|
|
f1e2837d63 | ||
|
|
54b05bfb77 | ||
|
|
524d29aa80 | ||
|
|
c2c0ef9927 | ||
|
|
66373e9b13 | ||
|
|
7d33b437fa | ||
|
|
b9dc3be463 | ||
|
|
86ec62f98a | ||
|
|
52bdfc659e | ||
|
|
d782b029e1 | ||
|
|
49c0955f97 | ||
|
|
eed24b083a | ||
|
|
f95ecd66fc | ||
|
|
d34c0a3c49 | ||
|
|
11a5a7ca45 | ||
|
|
a6d353c419 | ||
|
|
d6556e6519 | ||
|
|
12af67066d | ||
|
|
7a20ef65f6 | ||
|
|
2f80d71c3e | ||
|
|
d4e0849970 | ||
|
|
e132a267aa | ||
|
|
a027f4edfb | ||
|
|
570f8d01df | ||
|
|
7938adcdfc | ||
|
|
20c08bb740 | ||
|
|
2bcf2631b9 |
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
3126
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
3126
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
107
.github/poetry/cpu/pyproject.toml
vendored
Normal file
107
.github/poetry/cpu/pyproject.toml
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
[tool.poetry]
|
||||
name = "lerobot"
|
||||
version = "0.1.0"
|
||||
description = "Le robot is learning"
|
||||
authors = [
|
||||
"Rémi Cadène <re.cadene@gmail.com>",
|
||||
"Simon Alibert <alibert.sim@gmail.com>",
|
||||
]
|
||||
repository = "https://github.com/Cadene/lerobot"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
cython = "^3.0.8"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
dm-env = "^1.6"
|
||||
pandas = "^2.2.1"
|
||||
wandb = "^0.16.3"
|
||||
moviepy = "^1.0.3"
|
||||
imageio = {extras = ["pyav"], version = "^2.34.0"}
|
||||
gdown = "^5.1.0"
|
||||
hydra-core = "^1.3.2"
|
||||
einops = "^0.7.0"
|
||||
pygame = "^2.5.2"
|
||||
pymunk = "^6.6.0"
|
||||
zarr = "^2.17.0"
|
||||
shapely = "^2.0.3"
|
||||
scikit-image = "^0.22.0"
|
||||
numba = "^0.59.0"
|
||||
mpmath = "^1.3.0"
|
||||
torch = {version = "^2.2.1", source = "torch-cpu"}
|
||||
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||
mujoco = "^3.1.2"
|
||||
mujoco-py = "^2.1.2.14"
|
||||
gym = "^0.26.2"
|
||||
opencv-python = "^4.9.0.80"
|
||||
diffusers = "^0.26.3"
|
||||
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
||||
h5py = "^3.10.0"
|
||||
dm = "^1.3"
|
||||
dm-control = "^1.0.16"
|
||||
huggingface-hub = "^0.21.4"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "torch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
priority = "supplemental"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||
build-backend = "poetry_dynamic_versioning.backend"
|
||||
144
.github/workflows/test.yml
vendored
Normal file
144
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [opened, synchronize, reopened, labeled]
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
test:
|
||||
if: |
|
||||
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
|
||||
${{ github.event_name == 'push' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
POETRY_VERSION: 1.8.1
|
||||
DATA_DIR: tests/data
|
||||
TMPDIR: ~/tmp
|
||||
TEMP: ~/tmp
|
||||
TMP: ~/tmp
|
||||
PYOPENGL_PLATFORM: egl
|
||||
MUJOCO_GL: egl
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
steps:
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
#----------------------------------------------
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up python
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
#----------------------------------------------
|
||||
# install & configure poetry
|
||||
#----------------------------------------------
|
||||
- name: Load cached Poetry installation
|
||||
id: restore-poetry-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
version: ${{ env.POETRY_VERSION }}
|
||||
virtualenvs-create: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Save cached Poetry installation
|
||||
if: |
|
||||
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-poetry-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Configure Poetry
|
||||
run: poetry config virtualenvs.in-project true
|
||||
|
||||
#----------------------------------------------
|
||||
# install dependencies
|
||||
#----------------------------------------------
|
||||
# TODO(aliberts): move to gpu runners
|
||||
- name: Select cpu dependencies # HACK
|
||||
run: cp -t . .github/poetry/cpu/pyproject.toml .github/poetry/cpu/poetry.lock
|
||||
|
||||
- name: Load cached venv
|
||||
id: restore-dependencies-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
mkdir ~/tmp
|
||||
poetry install --no-interaction --no-root
|
||||
|
||||
- name: Save cached venv
|
||||
if: |
|
||||
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-dependencies-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
#----------------------------------------------
|
||||
# install project
|
||||
#----------------------------------------------
|
||||
- name: Install project
|
||||
run: poetry install --no-interaction
|
||||
|
||||
#----------------------------------------------
|
||||
# run tests
|
||||
#----------------------------------------------
|
||||
- name: Run tests
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest tests
|
||||
|
||||
- name: Test train pusht end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.job.name=pusht \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=1 \
|
||||
hydra.run.dir=tests/outputs/
|
||||
|
||||
- name: Test eval pusht end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
hydra.job.name=pusht \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/models/1.pt
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,6 +1,3 @@
|
||||
# Custom
|
||||
diffusion_policy
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -54,6 +51,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
exclude: ^(data/|tests/|diffusion_policy/)
|
||||
exclude: ^(data/|tests/)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
|
||||
278
LICENSE
Normal file
278
LICENSE
Normal file
@@ -0,0 +1,278 @@
|
||||
Copyright 2024 The Hugging Face team. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from Diffusion Policy, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from FOWM, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Yunhai Feng
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from ALOHA, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Tony Z. Zhao
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
89
README.md
89
README.md
@@ -2,7 +2,7 @@
|
||||
|
||||
## Installation
|
||||
|
||||
Create a virtual environment with python 3.10, e.g. using `conda`:
|
||||
Create a virtual environment with Python 3.10, e.g. using `conda`:
|
||||
```
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
@@ -24,11 +24,9 @@ mkdir ~/tmp
|
||||
export TMPDIR='~/tmp'
|
||||
```
|
||||
|
||||
Install `diffusion_policy` #HACK
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||
```
|
||||
# from this directory
|
||||
git clone https://github.com/real-stanford/diffusion_policy
|
||||
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
|
||||
wandb login
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -61,19 +59,10 @@ env=pusht
|
||||
|
||||
## TODO
|
||||
|
||||
- [x] priority update doesnt match FOWM or original paper
|
||||
- [x] self.step=100000 should be updated at every step to adjust to horizon of planner
|
||||
- [ ] prefetch replay buffer to speedup training
|
||||
- [ ] parallelize env to speedup eval
|
||||
- [ ] clean checkpointing / loading
|
||||
- [ ] clean logging
|
||||
- [ ] clean config
|
||||
- [ ] clean hyperparameter tuning
|
||||
- [ ] add pusht
|
||||
- [ ] add aloha
|
||||
- [ ] add act
|
||||
- [ ] add diffusion
|
||||
- [ ] add aloha 2
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
|
||||
|
||||
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
|
||||
|
||||
|
||||
## Profile
|
||||
|
||||
@@ -114,7 +103,69 @@ pre-commit install
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
**Adding dependencies (temporary)**
|
||||
|
||||
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
|
||||
|
||||
```
|
||||
# Run in this directory, adds the package to the main env with cuda
|
||||
poetry add some-package
|
||||
|
||||
# Adds the same package to the cpu env
|
||||
cd .github/poetry/cpu && poetry add some-package
|
||||
```
|
||||
|
||||
**Tests**
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
```
|
||||
pytest -sx tests
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
```
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
When adding a new dataset, mock it with
|
||||
```
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/<dataset_id> --out-data-dir tests/data/<dataset_id>
|
||||
```
|
||||
|
||||
Run tests
|
||||
```
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
**Datasets**
|
||||
|
||||
To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```
|
||||
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
|
||||
```
|
||||
|
||||
Then you can upload it to the hub with:
|
||||
```
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
|
||||
```
|
||||
|
||||
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
|
||||
```
|
||||
HF_USER=cadene
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
|
||||
## Acknowledgment
|
||||
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
- Our ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
name: lerobot
|
||||
dependencies:
|
||||
- python=3.8.16
|
||||
- pytorch::pytorch=1.13.1
|
||||
- pytorch::torchvision=0.14.1
|
||||
- nvidia::cudatoolkit=11.7
|
||||
- anaconda::pip
|
||||
- pip:
|
||||
- cython==0.29.33
|
||||
- mujoco==2.3.2
|
||||
- mujoco-py==2.1.2.14
|
||||
- termcolor
|
||||
- omegaconf
|
||||
- gym==0.21.0
|
||||
- dm-env==1.6
|
||||
- pandas
|
||||
- wandb
|
||||
- moviepy
|
||||
- imageio
|
||||
- gdown
|
||||
# - -e benchmarks/d4rl
|
||||
# TODO: verify this works
|
||||
- git+https://github.com/nicklashansen/simxarm.git@main#egg=simxarm
|
||||
@@ -0,0 +1 @@
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
__version__ = "0.0.0"
|
||||
""" To enable `lerobot.__version__` """
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
try:
|
||||
__version__ = version("lerobot")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "unknown"
|
||||
|
||||
159
lerobot/common/datasets/abstract.py
Normal file
159
lerobot/common/datasets/abstract.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.shuffle = shuffle
|
||||
self.root = root
|
||||
storage = self._download_or_load_dataset()
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("observation", "image"): "b c h w -> 1 c 1 1",
|
||||
("action",): "b c -> 1 c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self._transform
|
||||
|
||||
def set_transform(self, transform):
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
if isinstance(transform, list):
|
||||
self._transform = Compose(*transform)
|
||||
else:
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = Path(self.data_dir) / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||
if self.root is None:
|
||||
self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
|
||||
else:
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
return TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=self._storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
||||
# compute mean, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
if key not in mean:
|
||||
# first batch initialize mean, min, max
|
||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||
else:
|
||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
batch = rb.sample()
|
||||
|
||||
for key in self.stats_patterns:
|
||||
mean[key] /= num_batch
|
||||
|
||||
# compute std, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
if key not in std:
|
||||
# first batch initialize std
|
||||
std[key] = (batch_mean - mean[key]) ** 2
|
||||
else:
|
||||
std[key] += (batch_mean - mean[key]) ** 2
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / num_batch)
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
183
lerobot/common/datasets/aloha.py
Normal file
183
lerobot/common/datasets/aloha.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
EP48_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
EP49_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
NUM_EPISODES = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
EPISODE_LEN = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
CAMERAS = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in DATASET_IDS
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert dataset_id in DATASET_IDS
|
||||
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("action",): "b c -> 1 c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
|
||||
return d
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_num_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_num_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_num_frames=}")
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
idxtd = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
)
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
@@ -1,36 +1,21 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
# dataset_d4rl = D4RLExperienceReplay(
|
||||
# dataset_id="maze2d-umaze-v1",
|
||||
# split_trajs=False,
|
||||
# batch_size=1,
|
||||
# sampler=SamplerWithoutReplacement(drop_last=False),
|
||||
# prefetch=4,
|
||||
# direct_download=True,
|
||||
# )
|
||||
|
||||
# dataset_openx = OpenXExperienceReplay(
|
||||
# "cmu_stretch",
|
||||
# batch_size=1,
|
||||
# num_slices=1,
|
||||
# #download="force",
|
||||
# streaming=False,
|
||||
# root="data",
|
||||
# )
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
def make_offline_buffer(
|
||||
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -43,50 +28,104 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
overwrite_sampler = sampler is not None
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if not overwrite_sampler:
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
|
||||
if cfg.offline_prioritized_sampler:
|
||||
logging.info("use prioritized sampler for offline dataset")
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
logging.info("use simple sampler for offline dataset")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||
offline_buffer = SimxarmExperienceReplay(
|
||||
f"xarm_{cfg.env.task}_medium",
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root=str(DATA_DIR),
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
clsfunc = SimxarmExperienceReplay
|
||||
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
streaming=False,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
|
||||
clsfunc = PushtExperienceReplay
|
||||
dataset_id = "pusht"
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaExperienceReplay
|
||||
|
||||
clsfunc = AlohaExperienceReplay
|
||||
dataset_id = f"aloha_{cfg.env.task}"
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -12,21 +9,18 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
@@ -54,8 +48,10 @@ def add_tee(
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=DEFAULT_TEE_MASK,
|
||||
mask=None,
|
||||
):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
@@ -87,114 +83,37 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("pusht")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
self.root = root
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
stats = self._compute_or_load_stats(storage)
|
||||
transform = NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
|
||||
# We need to automate this for tdmpc and others
|
||||
# ("observation", "image"),
|
||||
("observation", "state"),
|
||||
# TODO(rcadene): for tdmpc, we might want next image and state
|
||||
# ("next", "observation", "image"),
|
||||
# ("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
mode="min_max",
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
transform.stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self) -> Path:
|
||||
return None if self.streaming else self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_path_root.is_dir()
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
raw_dir = self.root / "raw"
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -208,6 +127,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
@@ -225,6 +147,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
# idx1 = 51
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
@@ -266,8 +190,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
print("before " + """episode = TensorDict(""")
|
||||
episode = TensorDict(
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
@@ -286,120 +209,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(episode)
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
return stats
|
||||
|
||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||
stats_path = self.root / self.dataset_id / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(storage)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -7,130 +7,69 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
import gdown
|
||||
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return len(self)
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
# TODO(rcadene): finish download
|
||||
download()
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self):
|
||||
if self.streaming:
|
||||
return None
|
||||
return self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self):
|
||||
return os.path.exists(self.data_path_root)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / "buffer.pkl"
|
||||
dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
@@ -172,7 +111,7 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
||||
80
lerobot/common/envs/abstract.py
Normal file
80
lerobot/common/envs/abstract.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import abc
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
self._rendering_hooks = []
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
self._make_env()
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def register_rendering_hook(self, func):
|
||||
self._rendering_hooks.append(func)
|
||||
|
||||
def call_rendering_hooks(self):
|
||||
for func in self._rendering_hooks:
|
||||
func(self)
|
||||
|
||||
def reset_rendering_hooks(self):
|
||||
self._rendering_hooks = []
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_env(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,59 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,48 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,53 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,42 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
@@ -0,0 +1,38 @@
|
||||
<mujocoinclude>
|
||||
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||
|
||||
<asset>
|
||||
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||
</asset>
|
||||
|
||||
<visual>
|
||||
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||
<quality shadowsize="4096" offsamples="4"/>
|
||||
<headlight ambient="0.4 0.4 0.4"/>
|
||||
</visual>
|
||||
|
||||
<worldbody>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||
dir='1 1 -1'/>
|
||||
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||
dir='0 -1 -1'/>
|
||||
|
||||
<body name="table" pos="0 .6 0">
|
||||
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||
</body>
|
||||
<body name="midair" pos="0 .6 0.2">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||
</body>
|
||||
|
||||
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
|
||||
</worldbody>
|
||||
|
||||
|
||||
|
||||
</mujocoinclude>
|
||||
BIN
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
Normal file
Binary file not shown.
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
@@ -0,0 +1,17 @@
|
||||
<mujocoinclude>
|
||||
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||
<asset>
|
||||
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||
</asset>
|
||||
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
163
lerobot/common/envs/aloha/constants.py
Normal file
163
lerobot/common/envs/aloha/constants.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from pathlib import Path
|
||||
|
||||
### Simulation envs fixed constants
|
||||
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
|
||||
FPS = 50
|
||||
|
||||
|
||||
JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
|
||||
START_ARM_POSE = [
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
]
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
|
||||
def normalize_master_gripper_position(x):
|
||||
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_position(x):
|
||||
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_position(x):
|
||||
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_position(x):
|
||||
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def convert_position_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
|
||||
|
||||
|
||||
def normalizer_master_gripper_joint(x):
|
||||
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_joint(x):
|
||||
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_joint(x):
|
||||
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_joint(x):
|
||||
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def convert_join_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
|
||||
|
||||
|
||||
def normalize_master_gripper_velocity(x):
|
||||
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_velocity(x):
|
||||
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def convert_master_from_position_to_joint(x):
|
||||
return (
|
||||
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_master_from_joint_to_position(x):
|
||||
return unnormalize_master_gripper_position(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_position_to_join(x):
|
||||
return (
|
||||
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_joint_to_position(x):
|
||||
return unnormalize_puppet_gripper_position(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
311
lerobot/common/envs/aloha/env.py
Normal file
311
lerobot/common/envs/aloha/env.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
ACTIONS,
|
||||
ASSETS_DIR,
|
||||
DT,
|
||||
JOINTS,
|
||||
)
|
||||
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
|
||||
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||
InsertionEndEffectorTask,
|
||||
TransferCubeEndEffectorTask,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError()
|
||||
|
||||
self._env = self._make_env_task(self.task)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||
return image
|
||||
|
||||
def _make_env_task(self, task_name):
|
||||
# time limit is controlled by StepCounter in env factory
|
||||
time_limit = float("inf")
|
||||
|
||||
if "sim_transfer_cube" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeTask(random=False)
|
||||
elif "sim_insertion" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionTask(random=False)
|
||||
elif "sim_end_effector_transfer_cube" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeEndEffectorTask(random=False)
|
||||
elif "sim_end_effector_insertion" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionEndEffectorTask(random=False)
|
||||
else:
|
||||
raise NotImplementedError(task_name)
|
||||
|
||||
env = control.Environment(
|
||||
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
|
||||
)
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["images"]["top"].copy())
|
||||
image = einops.rearrange(image, "h w c -> c h w")
|
||||
assert image.dtype == torch.uint8
|
||||
obs = {"image": {"top": image}}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO(rcadene):
|
||||
raise NotImplementedError()
|
||||
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
_, reward, discount, raw_obs = self._env.step(action[i])
|
||||
del discount # not used
|
||||
|
||||
# TOOD(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
sum_reward += reward
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if self.from_pixels:
|
||||
if isinstance(self.image_size, int):
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
elif OmegaConf.is_list(self.image_size):
|
||||
assert len(self.image_size) == 3 # c h w
|
||||
assert self.image_size[0] == 3 # c is RGB
|
||||
image_shape = tuple(self.image_size)
|
||||
else:
|
||||
raise ValueError(self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = {
|
||||
"top": BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
}
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
# TODO(rcadene): valid when controling end effector?
|
||||
# action_space = self._env.action_spec()
|
||||
# self.action_spec = BoundedTensorSpec(
|
||||
# low=action_space.minimum,
|
||||
# high=action_space.maximum,
|
||||
# shape=action_space.shape,
|
||||
# dtype=torch.float32,
|
||||
# device=self.device,
|
||||
# )
|
||||
|
||||
# TODO(rcaene): add bounds (where are they????)
|
||||
self.action_spec = BoundedTensorSpec(
|
||||
shape=(len(ACTIONS)),
|
||||
low=-1,
|
||||
high=1,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
# TODO(rcadene): seed the env
|
||||
# self._env.seed(seed)
|
||||
logging.warning("Aloha env is not seeded")
|
||||
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7 : 7 + 6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7 + 6]
|
||||
|
||||
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
|
||||
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate(
|
||||
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
|
||||
)
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXEndEffectorTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
|
||||
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array(
|
||||
[
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
]
|
||||
)
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs["mocap_pose_left"] = np.concatenate(
|
||||
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
|
||||
).copy()
|
||||
obs["mocap_pose_right"] = np.concatenate(
|
||||
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
|
||||
).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs["gripper_ctrl"] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id("red_box_joint", "joint")
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
|
||||
def id2index(j_id):
|
||||
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
39
lerobot/common/envs/aloha/utils.py
Normal file
39
lerobot/common/envs/aloha/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
@@ -1,4 +1,4 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
@@ -9,6 +9,7 @@ def make_env(cfg, transform=None):
|
||||
"image_size": cfg.env.image_size,
|
||||
# TODO(rcadene): do we want a specific eval_env_seed?
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
@@ -17,11 +18,16 @@ def make_env(cfg, transform=None):
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht import PushtEnv
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
clsfunc = PushtEnv
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
@@ -32,7 +38,13 @@ def make_env(cfg, transform=None):
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import importlib
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -10,18 +11,18 @@ from torchrl.data.tensor_specs import (
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
class PushtEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task="pusht",
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
@@ -31,42 +32,31 @@ class PushtEnv(EnvBase):
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_diffpolicy:
|
||||
raise ImportError("Cannot import diffusion_policy.")
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
@@ -123,6 +113,8 @@ class PushtEnv(EnvBase):
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
@@ -133,7 +125,7 @@ class PushtEnv(EnvBase):
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = action.repeat(self.frame_skip, 1)
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
@@ -155,6 +147,8 @@ class PushtEnv(EnvBase):
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
@@ -172,24 +166,24 @@ class PushtEnv(EnvBase):
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs, *image_shape)
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = self._env.observation_space["agent_pos"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=512,
|
||||
shape=self._env.observation_space["agent_pos"].shape,
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -197,11 +191,11 @@ class PushtEnv(EnvBase):
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO:
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
@@ -0,0 +1,378 @@
|
||||
import collections
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import pymunk.pygame_util
|
||||
import shapely.geometry as sg
|
||||
import skimage.transform as st
|
||||
from gym import spaces
|
||||
from pymunk.vec2d import Vec2d
|
||||
|
||||
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
|
||||
|
||||
|
||||
def pymunk_to_shapely(body, shapes):
|
||||
geoms = []
|
||||
for shape in shapes:
|
||||
if isinstance(shape, pymunk.shapes.Poly):
|
||||
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
||||
verts += [verts[0]]
|
||||
geoms.append(sg.Polygon(verts))
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
||||
geom = sg.MultiPolygon(geoms)
|
||||
return geom
|
||||
|
||||
|
||||
class PushTEnv(gym.Env):
|
||||
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
|
||||
reward_range = (0.0, 1.0)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
legacy=False,
|
||||
block_cog=None,
|
||||
damping=None,
|
||||
render_action=True,
|
||||
render_size=96,
|
||||
reset_to_state=None,
|
||||
):
|
||||
self._seed = None
|
||||
self.seed()
|
||||
self.window_size = ws = 512 # The size of the PyGame window
|
||||
self.render_size = render_size
|
||||
self.sim_hz = 100
|
||||
# Local controller params.
|
||||
self.k_p, self.k_v = 100, 20 # PD control.z
|
||||
self.control_hz = self.metadata["video.frames_per_second"]
|
||||
# legcay set_state for data compatibility
|
||||
self.legacy = legacy
|
||||
|
||||
# agent_pos, block_pos, block_angle
|
||||
self.observation_space = spaces.Box(
|
||||
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
|
||||
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
|
||||
shape=(5,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
# positional goal for agent
|
||||
self.action_space = spaces.Box(
|
||||
low=np.array([0, 0], dtype=np.float64),
|
||||
high=np.array([ws, ws], dtype=np.float64),
|
||||
shape=(2,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
self.block_cog = block_cog
|
||||
self.damping = damping
|
||||
self.render_action = render_action
|
||||
|
||||
"""
|
||||
If human-rendering is used, `self.window` will be a reference
|
||||
to the window that we draw to. `self.clock` will be a clock that is used
|
||||
to ensure that the environment is rendered at the correct framerate in
|
||||
human-mode. They will remain `None` until human-mode is used for the
|
||||
first time.
|
||||
"""
|
||||
self.window = None
|
||||
self.clock = None
|
||||
self.screen = None
|
||||
|
||||
self.space = None
|
||||
self.teleop = None
|
||||
self.render_buffer = None
|
||||
self.latest_action = None
|
||||
self.reset_to_state = reset_to_state
|
||||
|
||||
def reset(self):
|
||||
seed = self._seed
|
||||
self._setup()
|
||||
if self.block_cog is not None:
|
||||
self.block.center_of_gravity = self.block_cog
|
||||
if self.damping is not None:
|
||||
self.space.damping = self.damping
|
||||
|
||||
# use legacy RandomState for compatibility
|
||||
state = self.reset_to_state
|
||||
if state is None:
|
||||
rs = np.random.RandomState(seed=seed)
|
||||
state = np.array(
|
||||
[
|
||||
rs.randint(50, 450),
|
||||
rs.randint(50, 450),
|
||||
rs.randint(100, 400),
|
||||
rs.randint(100, 400),
|
||||
rs.randn() * 2 * np.pi - np.pi,
|
||||
]
|
||||
)
|
||||
self._set_state(state)
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
dt = 1.0 / self.sim_hz
|
||||
self.n_contact_points = 0
|
||||
n_steps = self.sim_hz // self.control_hz
|
||||
if action is not None:
|
||||
self.latest_action = action
|
||||
for _ in range(n_steps):
|
||||
# Step PD control.
|
||||
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
|
||||
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
|
||||
Vec2d(0, 0) - self.agent.velocity
|
||||
)
|
||||
self.agent.velocity += acceleration * dt
|
||||
|
||||
# Step physics.
|
||||
self.space.step(dt)
|
||||
|
||||
# compute reward
|
||||
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
|
||||
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
|
||||
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward = np.clip(coverage / self.success_threshold, 0, 1)
|
||||
done = coverage > self.success_threshold
|
||||
|
||||
observation = self._get_obs()
|
||||
info = self._get_info()
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def render(self, mode):
|
||||
return self._render_frame(mode)
|
||||
|
||||
def teleop_agent(self):
|
||||
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
|
||||
|
||||
def act(obs):
|
||||
act = None
|
||||
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
|
||||
if self.teleop or (mouse_position - self.agent.position).length < 30:
|
||||
self.teleop = True
|
||||
act = mouse_position
|
||||
return act
|
||||
|
||||
return TeleopAgent(act)
|
||||
|
||||
def _get_obs(self):
|
||||
obs = np.array(
|
||||
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
|
||||
)
|
||||
return obs
|
||||
|
||||
def _get_goal_pose_body(self, pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
# preserving the legacy assignment order for compatibility
|
||||
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||
body.position = pose[:2].tolist()
|
||||
body.angle = pose[2]
|
||||
return body
|
||||
|
||||
def _get_info(self):
|
||||
n_steps = self.sim_hz // self.control_hz
|
||||
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
|
||||
info = {
|
||||
"pos_agent": np.array(self.agent.position),
|
||||
"vel_agent": np.array(self.agent.velocity),
|
||||
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
|
||||
"goal_pose": self.goal_pose,
|
||||
"n_contacts": n_contact_points_per_step,
|
||||
}
|
||||
return info
|
||||
|
||||
def _render_frame(self, mode):
|
||||
if self.window is None and mode == "human":
|
||||
pygame.init()
|
||||
pygame.display.init()
|
||||
self.window = pygame.display.set_mode((self.window_size, self.window_size))
|
||||
if self.clock is None and mode == "human":
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
canvas = pygame.Surface((self.window_size, self.window_size))
|
||||
canvas.fill((255, 255, 255))
|
||||
self.screen = canvas
|
||||
|
||||
draw_options = DrawOptions(canvas)
|
||||
|
||||
# Draw goal pose.
|
||||
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||
for shape in self.block.shapes:
|
||||
goal_points = [
|
||||
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
|
||||
for v in shape.get_vertices()
|
||||
]
|
||||
goal_points += [goal_points[0]]
|
||||
pygame.draw.polygon(canvas, self.goal_color, goal_points)
|
||||
|
||||
# Draw agent and block.
|
||||
self.space.debug_draw(draw_options)
|
||||
|
||||
if mode == "human":
|
||||
# The following line copies our drawings from `canvas` to the visible window
|
||||
self.window.blit(canvas, canvas.get_rect())
|
||||
pygame.event.pump()
|
||||
pygame.display.update()
|
||||
|
||||
# the clock is already ticked during in step for "human"
|
||||
|
||||
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
|
||||
img = cv2.resize(img, (self.render_size, self.render_size))
|
||||
if self.render_action and self.latest_action is not None:
|
||||
action = np.array(self.latest_action)
|
||||
coord = (action / 512 * 96).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self.render_size)
|
||||
thickness = int(1 / 96 * self.render_size)
|
||||
cv2.drawMarker(
|
||||
img,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
return img
|
||||
|
||||
def close(self):
|
||||
if self.window is not None:
|
||||
pygame.display.quit()
|
||||
pygame.quit()
|
||||
|
||||
def seed(self, seed=None):
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 25536)
|
||||
self._seed = seed
|
||||
self.np_random = np.random.default_rng(seed)
|
||||
|
||||
def _handle_collision(self, arbiter, space, data):
|
||||
self.n_contact_points += len(arbiter.contact_point_set.points)
|
||||
|
||||
def _set_state(self, state):
|
||||
if isinstance(state, np.ndarray):
|
||||
state = state.tolist()
|
||||
pos_agent = state[:2]
|
||||
pos_block = state[2:4]
|
||||
rot_block = state[4]
|
||||
self.agent.position = pos_agent
|
||||
# setting angle rotates with respect to center of mass
|
||||
# therefore will modify the geometric position
|
||||
# if not the same as CoM
|
||||
# therefore should be modified first.
|
||||
if self.legacy:
|
||||
# for compatibility with legacy data
|
||||
self.block.position = pos_block
|
||||
self.block.angle = rot_block
|
||||
else:
|
||||
self.block.angle = rot_block
|
||||
self.block.position = pos_block
|
||||
|
||||
# Run physics to take effect
|
||||
self.space.step(1.0 / self.sim_hz)
|
||||
|
||||
def _set_state_local(self, state_local):
|
||||
agent_pos_local = state_local[:2]
|
||||
block_pose_local = state_local[2:]
|
||||
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
|
||||
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
|
||||
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
|
||||
agent_pos_new = tf_img_new(agent_pos_local)
|
||||
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
|
||||
self._set_state(new_state)
|
||||
return new_state
|
||||
|
||||
def _setup(self):
|
||||
self.space = pymunk.Space()
|
||||
self.space.gravity = 0, 0
|
||||
self.space.damping = 0
|
||||
self.teleop = False
|
||||
self.render_buffer = []
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
self._add_segment((5, 506), (5, 5), 2),
|
||||
self._add_segment((5, 5), (506, 5), 2),
|
||||
self._add_segment((506, 5), (506, 506), 2),
|
||||
self._add_segment((5, 506), (506, 506), 2),
|
||||
]
|
||||
self.space.add(*walls)
|
||||
|
||||
# Add agent, block, and goal zone.
|
||||
self.agent = self.add_circle((256, 400), 15)
|
||||
self.block = self.add_tee((256, 300), 0)
|
||||
self.goal_color = pygame.Color("LightGreen")
|
||||
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
|
||||
# Add collision handling
|
||||
self.collision_handeler = self.space.add_collision_handler(0, 0)
|
||||
self.collision_handeler.post_solve = self._handle_collision
|
||||
self.n_contact_points = 0
|
||||
|
||||
self.max_score = 50 * 100
|
||||
self.success_threshold = 0.95 # 95% coverage.
|
||||
|
||||
def _add_segment(self, a, b, radius):
|
||||
shape = pymunk.Segment(self.space.static_body, a, b, radius)
|
||||
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||
return shape
|
||||
|
||||
def add_circle(self, position, radius):
|
||||
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
|
||||
body.position = position
|
||||
body.friction = 1
|
||||
shape = pymunk.Circle(body, radius)
|
||||
shape.color = pygame.Color("RoyalBlue")
|
||||
self.space.add(body, shape)
|
||||
return body
|
||||
|
||||
def add_box(self, position, height, width):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (height, width))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
body.position = position
|
||||
shape = pymunk.Poly.create_box(body, (height, width))
|
||||
shape.color = pygame.Color("LightSlateGray")
|
||||
self.space.add(body, shape)
|
||||
return body
|
||||
|
||||
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
(-length * scale / 2, scale),
|
||||
(length * scale / 2, scale),
|
||||
(length * scale / 2, 0),
|
||||
(-length * scale / 2, 0),
|
||||
]
|
||||
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
vertices2 = [
|
||||
(-scale / 2, scale),
|
||||
(-scale / 2, length * scale),
|
||||
(scale / 2, length * scale),
|
||||
(scale / 2, scale),
|
||||
]
|
||||
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||
shape1 = pymunk.Poly(body, vertices1)
|
||||
shape2 = pymunk.Poly(body, vertices2)
|
||||
shape1.color = pygame.Color(color)
|
||||
shape2.color = pygame.Color(color)
|
||||
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||
body.position = position
|
||||
body.angle = angle
|
||||
body.friction = 1
|
||||
self.space.add(body, shape1, shape2)
|
||||
return body
|
||||
55
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
55
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
|
||||
class PushTImageEnv(PushTEnv):
|
||||
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
|
||||
|
||||
def __init__(self, legacy=False, block_cog=None, damping=None, render_size=96):
|
||||
super().__init__(
|
||||
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
|
||||
)
|
||||
ws = self.window_size
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
|
||||
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
|
||||
}
|
||||
)
|
||||
self.render_cache = None
|
||||
|
||||
def _get_obs(self):
|
||||
img = super()._render_frame(mode="rgb_array")
|
||||
|
||||
agent_pos = np.array(self.agent.position)
|
||||
img_obs = np.moveaxis(img, -1, 0)
|
||||
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||
|
||||
# draw action
|
||||
if self.latest_action is not None:
|
||||
action = np.array(self.latest_action)
|
||||
coord = (action / 512 * 96).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self.render_size)
|
||||
thickness = int(1 / 96 * self.render_size)
|
||||
cv2.drawMarker(
|
||||
img,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
self.render_cache = img
|
||||
|
||||
return obs
|
||||
|
||||
def render(self, mode):
|
||||
assert mode == "rgb_array"
|
||||
|
||||
if self.render_cache is None:
|
||||
self._get_obs()
|
||||
|
||||
return self.render_cache
|
||||
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# ----------------------------------------------------------------------------
|
||||
# pymunk
|
||||
# Copyright (c) 2007-2016 Victor Blomqvist
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
"""This submodule contains helper functions to help with quick prototyping
|
||||
using pymunk together with pygame.
|
||||
|
||||
Intended to help with debugging and prototyping, not for actual production use
|
||||
in a full application. The methods contained in this module is opinionated
|
||||
about your coordinate system and not in any way optimized.
|
||||
"""
|
||||
|
||||
__docformat__ = "reStructuredText"
|
||||
|
||||
__all__ = [
|
||||
"DrawOptions",
|
||||
"get_mouse_pos",
|
||||
"to_pygame",
|
||||
"from_pygame",
|
||||
# "lighten",
|
||||
"positive_y_is_up",
|
||||
]
|
||||
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
from pymunk.space_debug_draw_options import SpaceDebugColor
|
||||
from pymunk.vec2d import Vec2d
|
||||
|
||||
positive_y_is_up: bool = False
|
||||
"""Make increasing values of y point upwards.
|
||||
|
||||
When True::
|
||||
|
||||
y
|
||||
^
|
||||
| . (3, 3)
|
||||
|
|
||||
| . (2, 2)
|
||||
|
|
||||
+------ > x
|
||||
|
||||
When False::
|
||||
|
||||
+------ > x
|
||||
|
|
||||
| . (2, 2)
|
||||
|
|
||||
| . (3, 3)
|
||||
v
|
||||
y
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DrawOptions(pymunk.SpaceDebugDrawOptions):
|
||||
def __init__(self, surface: pygame.Surface) -> None:
|
||||
"""Draw a pymunk.Space on a pygame.Surface object.
|
||||
|
||||
Typical usage::
|
||||
|
||||
>>> import pymunk
|
||||
>>> surface = pygame.Surface((10,10))
|
||||
>>> space = pymunk.Space()
|
||||
>>> options = pymunk.pygame_util.DrawOptions(surface)
|
||||
>>> space.debug_draw(options)
|
||||
|
||||
You can control the color of a shape by setting shape.color to the color
|
||||
you want it drawn in::
|
||||
|
||||
>>> c = pymunk.Circle(None, 10)
|
||||
>>> c.color = pygame.Color("pink")
|
||||
|
||||
See pygame_util.demo.py for a full example
|
||||
|
||||
Since pygame uses a coordinate system where y points down (in contrast
|
||||
to many other cases), you either have to make the physics simulation
|
||||
with Pymunk also behave in that way, or flip everything when you draw.
|
||||
|
||||
The easiest is probably to just make the simulation behave the same
|
||||
way as Pygame does. In that way all coordinates used are in the same
|
||||
orientation and easy to reason about::
|
||||
|
||||
>>> space = pymunk.Space()
|
||||
>>> space.gravity = (0, -1000)
|
||||
>>> body = pymunk.Body()
|
||||
>>> body.position = (0, 0) # will be positioned in the top left corner
|
||||
>>> space.debug_draw(options)
|
||||
|
||||
To flip the drawing its possible to set the module property
|
||||
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
|
||||
the simulation upside down before drawing::
|
||||
|
||||
>>> positive_y_is_up = True
|
||||
>>> body = pymunk.Body()
|
||||
>>> body.position = (0, 0)
|
||||
>>> # Body will be position in bottom left corner
|
||||
|
||||
:Parameters:
|
||||
surface : pygame.Surface
|
||||
Surface that the objects will be drawn on
|
||||
"""
|
||||
self.surface = surface
|
||||
super().__init__()
|
||||
|
||||
def draw_circle(
|
||||
self,
|
||||
pos: Vec2d,
|
||||
angle: float,
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
p = to_pygame(pos, self.surface)
|
||||
|
||||
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
|
||||
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
|
||||
|
||||
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
|
||||
# p2 = to_pygame(circle_edge, self.surface)
|
||||
# line_r = 2 if radius > 20 else 1
|
||||
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
|
||||
|
||||
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
|
||||
p1 = to_pygame(a, self.surface)
|
||||
p2 = to_pygame(b, self.surface)
|
||||
|
||||
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
|
||||
|
||||
def draw_fat_segment(
|
||||
self,
|
||||
a: Tuple[float, float],
|
||||
b: Tuple[float, float],
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
p1 = to_pygame(a, self.surface)
|
||||
p2 = to_pygame(b, self.surface)
|
||||
|
||||
r = round(max(1, radius * 2))
|
||||
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
|
||||
if r > 2:
|
||||
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
|
||||
if orthog[0] == 0 and orthog[1] == 0:
|
||||
return
|
||||
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
|
||||
orthog[0] = round(orthog[0] * scale)
|
||||
orthog[1] = round(orthog[1] * scale)
|
||||
points = [
|
||||
(p1[0] - orthog[0], p1[1] - orthog[1]),
|
||||
(p1[0] + orthog[0], p1[1] + orthog[1]),
|
||||
(p2[0] + orthog[0], p2[1] + orthog[1]),
|
||||
(p2[0] - orthog[0], p2[1] - orthog[1]),
|
||||
]
|
||||
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
|
||||
pygame.draw.circle(
|
||||
self.surface,
|
||||
fill_color.as_int(),
|
||||
(round(p1[0]), round(p1[1])),
|
||||
round(radius),
|
||||
)
|
||||
pygame.draw.circle(
|
||||
self.surface,
|
||||
fill_color.as_int(),
|
||||
(round(p2[0]), round(p2[1])),
|
||||
round(radius),
|
||||
)
|
||||
|
||||
def draw_polygon(
|
||||
self,
|
||||
verts: Sequence[Tuple[float, float]],
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
ps = [to_pygame(v, self.surface) for v in verts]
|
||||
ps += [ps[0]]
|
||||
|
||||
radius = 2
|
||||
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
|
||||
|
||||
if radius > 0:
|
||||
for i in range(len(verts)):
|
||||
a = verts[i]
|
||||
b = verts[(i + 1) % len(verts)]
|
||||
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
|
||||
|
||||
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
|
||||
p = to_pygame(pos, self.surface)
|
||||
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
|
||||
|
||||
|
||||
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Get position of the mouse pointer in pymunk coordinates."""
|
||||
p = pygame.mouse.get_pos()
|
||||
return from_pygame(p, surface)
|
||||
|
||||
|
||||
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Convenience method to convert pymunk coordinates to pygame surface
|
||||
local coordinates.
|
||||
|
||||
Note that in case positive_y_is_up is False, this function won't actually do
|
||||
anything except converting the point to integers.
|
||||
"""
|
||||
if positive_y_is_up:
|
||||
return round(p[0]), surface.get_height() - round(p[1])
|
||||
else:
|
||||
return round(p[0]), round(p[1])
|
||||
|
||||
|
||||
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Convenience method to convert pygame surface local coordinates to
|
||||
pymunk coordinates
|
||||
"""
|
||||
return to_pygame(p, surface)
|
||||
|
||||
|
||||
def light_color(color: SpaceDebugColor):
|
||||
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
|
||||
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
|
||||
return color
|
||||
@@ -1,6 +1,8 @@
|
||||
import importlib
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
@@ -10,9 +12,9 @@ from torchrl.data.tensor_specs import (
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
@@ -21,7 +23,7 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
@@ -31,19 +33,22 @@ class SimxarmEnv(EnvBase):
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_simxarm:
|
||||
raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
@@ -63,9 +68,6 @@ class SimxarmEnv(EnvBase):
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
@@ -90,15 +92,33 @@ class SimxarmEnv(EnvBase):
|
||||
if td is None or td.is_empty():
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
@@ -108,10 +128,32 @@ class SimxarmEnv(EnvBase):
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
raw_obs, reward, done, info = self._env.step(action[i])
|
||||
sum_reward += reward
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
@@ -126,23 +168,36 @@ class SimxarmEnv(EnvBase):
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(self._env.robot_state),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
# TODO:
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
@@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td):
|
||||
for key in self.in_keys:
|
||||
td[key] *= self.prod
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
self.original_dtypes[key] = td[key].dtype
|
||||
td[key] = td[key].type(torch.float32) * self.prod
|
||||
return td
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
||||
return td
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
obs_spec[key].space.high *= self.prod
|
||||
if obs_spec.get(key, None) is None:
|
||||
continue
|
||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
obs_spec[key].dtype = torch.float32
|
||||
return obs_spec
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,10 @@ from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
|
||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
||||
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
|
||||
@@ -34,7 +38,7 @@ class Logger:
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project or not entity
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
@@ -59,6 +63,7 @@ class Logger:
|
||||
resume=None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
|
||||
115
lerobot/common/policies/act/backbone.py
Normal file
115
lerobot/common/policies/act/backbone.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
from .utils import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
def __init__(
|
||||
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for _, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
212
lerobot/common/policies/act/detr_vae.py
Normal file
212
lerobot/common/policies/act/detr_vae.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vae = vae
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
# TODO(rcadene): understand what is env_state, and why it needs to be 7
|
||||
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if self.vae and is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for _ in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=args.state_dim,
|
||||
action_dim=args.action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
vae=args.vae,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
||||
217
lerobot/common/policies/act/policy.py
Normal file
217
lerobot/common/policies/act/policy.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
model = build(cfg)
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
self.to(self.device)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# self.lr_scheduler.step()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
|
||||
# observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# remove bsize=1
|
||||
action = action.squeeze(0)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.vae:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
||||
101
lerobot/common/policies/act/position_encoding.py
Normal file
101
lerobot/common/policies/act/position_encoding.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .utils import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = (
|
||||
torch.cat(
|
||||
[
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(x.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
n_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ("v2", "sine"):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
||||
elif args.position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(n_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
370
lerobot/common/policies/act/transformer.py
Normal file
370
lerobot/common/policies/act/transformer.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
477
lerobot/common/policies/act/utils.py
Normal file
477
lerobot/common/policies/act/utils.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
from packaging import version
|
||||
from torch import Tensor
|
||||
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {}".format(name, str(meter)))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
]
|
||||
)
|
||||
mega_b = 1024.0 * 1024.0
|
||||
for i, obj in enumerate(iterable):
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / mega_b,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
||||
|
||||
sha = "N/A"
|
||||
diff = "clean"
|
||||
branch = "N/A"
|
||||
try:
|
||||
sha = _run(["git", "rev-parse", "HEAD"])
|
||||
subprocess.check_output(["git", "diff"], cwd=cwd)
|
||||
diff = _run(["git", "diff-index", "HEAD"])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch, strict=False))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("not supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
|
||||
torch.int64
|
||||
)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
elif "SLURM_PROCID" in os.environ:
|
||||
args.rank = int(os.environ["SLURM_PROCID"])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
@@ -5,11 +5,33 @@ import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from einops import reduce
|
||||
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
|
||||
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
|
||||
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class BaseImagePolicy(ModuleAttrMixin):
|
||||
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict:
|
||||
str: B,To,*
|
||||
return: B,Ta,Da
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# reset state for stateful policies
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
# ========== training ===========
|
||||
# no standard training interface except setting normalizer
|
||||
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
|
||||
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
|
||||
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False,
|
||||
):
|
||||
super().__init__()
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
|
||||
local_cond_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList(
|
||||
[
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, "b h t -> b t h")
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
# encode local features
|
||||
h_local = []
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b t h -> b h t")
|
||||
return x
|
||||
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch.nn as nn
|
||||
|
||||
# from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# def test():
|
||||
# cb = Conv1dBlock(256, 128, kernel_size=3)
|
||||
# x = torch.zeros((1,256,16))
|
||||
# o = cb(x)
|
||||
294
lerobot/common/policies/diffusion/model/crop_randomizer.py
Normal file
294
lerobot/common/policies/diffusion/model/crop_randomizer.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
|
||||
import lerobot.common.policies.diffusion.model.tensor_utils as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
|
||||
"""
|
||||
Randomly sample crops at input, and then average across crop features at output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shape,
|
||||
crop_height,
|
||||
crop_width,
|
||||
num_crops=1,
|
||||
pos_enc=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_shape (tuple, list): shape of input (not including batch dimension)
|
||||
crop_height (int): crop height
|
||||
crop_width (int): crop width
|
||||
num_crops (int): number of random crops to take
|
||||
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
|
||||
location of the cropped pixels in the source image
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3 # (C, H, W)
|
||||
assert crop_height < input_shape[1]
|
||||
assert crop_width < input_shape[2]
|
||||
|
||||
self.input_shape = input_shape
|
||||
self.crop_height = crop_height
|
||||
self.crop_width = crop_width
|
||||
self.num_crops = num_crops
|
||||
self.pos_enc = pos_enc
|
||||
|
||||
def output_shape_in(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_in operation, where raw inputs (usually observation modalities)
|
||||
are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
|
||||
# the number of crops are reshaped into the batch dimension, increasing the batch
|
||||
# size from B to B * N
|
||||
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
|
||||
return [out_c, self.crop_height, self.crop_width]
|
||||
|
||||
def output_shape_out(self, input_shape=None):
|
||||
"""
|
||||
Function to compute output shape from inputs to this module. Corresponds to
|
||||
the @forward_out operation, where processed inputs (usually encoded observation
|
||||
modalities) are passed in.
|
||||
|
||||
Args:
|
||||
input_shape (iterable of int): shape of input. Does not include batch dimension.
|
||||
Some modules may not need this argument, if their output does not depend
|
||||
on the size of the input, or if they assume fixed size input.
|
||||
|
||||
Returns:
|
||||
out_shape ([int]): list of integers corresponding to output shape
|
||||
"""
|
||||
|
||||
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
|
||||
# and then pools to result in [B, ...], only the batch dimension changes,
|
||||
# and so the other dimensions retain their shape.
|
||||
return list(input_shape)
|
||||
|
||||
def forward_in(self, inputs):
|
||||
"""
|
||||
Samples N random crops for each input in the batch, and then reshapes
|
||||
inputs to [B * N, ...].
|
||||
"""
|
||||
assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions
|
||||
if self.training:
|
||||
# generate random crops
|
||||
out, _ = sample_random_image_crops(
|
||||
images=inputs,
|
||||
crop_height=self.crop_height,
|
||||
crop_width=self.crop_width,
|
||||
num_crops=self.num_crops,
|
||||
pos_enc=self.pos_enc,
|
||||
)
|
||||
# [B, N, ...] -> [B * N, ...]
|
||||
return tu.join_dimensions(out, 0, 1)
|
||||
else:
|
||||
# take center crop during eval
|
||||
out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width))
|
||||
if self.num_crops > 1:
|
||||
B, C, H, W = out.shape # noqa: N806
|
||||
out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W)
|
||||
# [B * N, ...]
|
||||
return out
|
||||
|
||||
def forward_out(self, inputs):
|
||||
"""
|
||||
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
|
||||
to result in shape [B, ...] to make sure the network output is consistent with
|
||||
what would have happened if there were no randomization.
|
||||
"""
|
||||
if self.num_crops <= 1:
|
||||
return inputs
|
||||
else:
|
||||
batch_size = inputs.shape[0] // self.num_crops
|
||||
out = tu.reshape_dimensions(
|
||||
inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops)
|
||||
)
|
||||
return out.mean(dim=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.forward_in(inputs)
|
||||
|
||||
def __repr__(self):
|
||||
"""Pretty print network."""
|
||||
header = "{}".format(str(self.__class__.__name__))
|
||||
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
|
||||
self.input_shape, self.crop_height, self.crop_width, self.num_crops
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
|
||||
"""
|
||||
Crops images at the locations specified by @crop_indices. Crops will be
|
||||
taken across all channels.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
|
||||
N is the number of crops to take per image and each entry corresponds
|
||||
to the pixel height and width of where to take the crop. Note that
|
||||
the indices can also be of shape [..., 2] if only 1 crop should
|
||||
be taken per image. Leading dimensions must be consistent with
|
||||
@images argument. Each index specifies the top left of the crop.
|
||||
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
|
||||
H and W are the height and width of @images and CH and CW are
|
||||
@crop_height and @crop_width.
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
Returns:
|
||||
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
|
||||
"""
|
||||
|
||||
# make sure length of input shapes is consistent
|
||||
assert crop_indices.shape[-1] == 2
|
||||
ndim_im_shape = len(images.shape)
|
||||
ndim_indices_shape = len(crop_indices.shape)
|
||||
assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2)
|
||||
|
||||
# maybe pad so that @crop_indices is shape [..., N, 2]
|
||||
is_padded = False
|
||||
if ndim_im_shape == ndim_indices_shape + 2:
|
||||
crop_indices = crop_indices.unsqueeze(-2)
|
||||
is_padded = True
|
||||
|
||||
# make sure leading dimensions between images and indices are consistent
|
||||
assert images.shape[:-3] == crop_indices.shape[:-2]
|
||||
|
||||
device = images.device
|
||||
image_c, image_h, image_w = images.shape[-3:]
|
||||
num_crops = crop_indices.shape[-2]
|
||||
|
||||
# make sure @crop_indices are in valid range
|
||||
assert (crop_indices[..., 0] >= 0).all().item()
|
||||
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
|
||||
assert (crop_indices[..., 1] >= 0).all().item()
|
||||
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
|
||||
|
||||
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
|
||||
|
||||
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
|
||||
crop_ind_grid_h = torch.arange(crop_height).to(device)
|
||||
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1)
|
||||
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
|
||||
crop_ind_grid_w = torch.arange(crop_width).to(device)
|
||||
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0)
|
||||
# combine into shape [CH, CW, 2]
|
||||
crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
|
||||
|
||||
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
|
||||
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
|
||||
# shape array that tells us which pixels from the corresponding source image to grab.
|
||||
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2]
|
||||
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape)
|
||||
|
||||
# For using @torch.gather, convert to flat indices from 2D indices, and also
|
||||
# repeat across the channel dimension. To get flat index of each pixel to grab for
|
||||
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
|
||||
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW]
|
||||
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW]
|
||||
all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW]
|
||||
|
||||
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
|
||||
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
|
||||
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
|
||||
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
|
||||
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
|
||||
reshape_axis = len(crops.shape) - 1
|
||||
crops = tu.reshape_dimensions(
|
||||
crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width)
|
||||
)
|
||||
|
||||
if is_padded:
|
||||
# undo padding -> [..., C, CH, CW]
|
||||
crops = crops.squeeze(-4)
|
||||
return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False):
|
||||
"""
|
||||
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
|
||||
@images.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): batch of images of shape [..., C, H, W]
|
||||
|
||||
crop_height (int): height of crop to take
|
||||
|
||||
crop_width (int): width of crop to take
|
||||
|
||||
num_crops (n): number of crops to sample
|
||||
|
||||
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
|
||||
encoding of the original source pixel locations. This means that the
|
||||
output crops will contain information about where in the source image
|
||||
it was sampled from.
|
||||
|
||||
Returns:
|
||||
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
|
||||
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
|
||||
|
||||
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
|
||||
"""
|
||||
device = images.device
|
||||
|
||||
# maybe add 2 channels of spatial encoding to the source image
|
||||
source_im = images
|
||||
if pos_enc:
|
||||
# spatial encoding [y, x] in [0, 1]
|
||||
h, w = source_im.shape[-2:]
|
||||
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
|
||||
pos_y = pos_y.float().to(device) / float(h)
|
||||
pos_x = pos_x.float().to(device) / float(w)
|
||||
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
|
||||
|
||||
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
|
||||
leading_shape = source_im.shape[:-3]
|
||||
position_enc = position_enc[(None,) * len(leading_shape)]
|
||||
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
|
||||
|
||||
# concat across channel dimension with input
|
||||
source_im = torch.cat((source_im, position_enc), dim=-3)
|
||||
|
||||
# make sure sample boundaries ensure crops are fully within the images
|
||||
image_c, image_h, image_w = source_im.shape[-3:]
|
||||
max_sample_h = image_h - crop_height
|
||||
max_sample_w = image_w - crop_width
|
||||
|
||||
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
|
||||
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
|
||||
# we will sample [B, N] indices, but this supports having more than one leading dimension,
|
||||
# or possibly no leading dimension.
|
||||
#
|
||||
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
|
||||
crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
|
||||
crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2]
|
||||
|
||||
crops = crop_image_from_indices(
|
||||
images=source_im,
|
||||
crop_indices=crop_inds,
|
||||
crop_height=crop_height,
|
||||
crop_width=crop_width,
|
||||
)
|
||||
|
||||
return crops, crop_inds
|
||||
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DictOfTensorMixin(nn.Module):
|
||||
def __init__(self, params_dict=None):
|
||||
super().__init__()
|
||||
if params_dict is None:
|
||||
params_dict = nn.ParameterDict()
|
||||
self.params_dict = params_dict
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
def dfs_add(dest, keys, value: torch.Tensor):
|
||||
if len(keys) == 1:
|
||||
dest[keys[0]] = value
|
||||
return
|
||||
|
||||
if keys[0] not in dest:
|
||||
dest[keys[0]] = nn.ParameterDict()
|
||||
dfs_add(dest[keys[0]], keys[1:], value)
|
||||
|
||||
def load_dict(state_dict, prefix):
|
||||
out_dict = nn.ParameterDict()
|
||||
for key, value in state_dict.items():
|
||||
value: torch.Tensor
|
||||
if key.startswith(prefix):
|
||||
param_keys = key[len(prefix) :].split(".")[1:]
|
||||
# if len(param_keys) == 0:
|
||||
# import pdb; pdb.set_trace()
|
||||
dfs_add(out_dict, param_keys, value.clone())
|
||||
return out_dict
|
||||
|
||||
self.params_dict = load_dict(state_dict, prefix + "params_dict")
|
||||
self.params_dict.requires_grad_(False)
|
||||
return
|
||||
84
lerobot/common/policies/diffusion/model/ema_model.py
Normal file
84
lerobot/common/policies/diffusion/model/ema_model.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EMAModel:
|
||||
"""
|
||||
Exponential Moving Average of models weights
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999
|
||||
):
|
||||
"""
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||
at 215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||
min_value (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
|
||||
self.averaged_model = model
|
||||
self.averaged_model.eval()
|
||||
self.averaged_model.requires_grad_(False)
|
||||
|
||||
self.update_after_step = update_after_step
|
||||
self.inv_gamma = inv_gamma
|
||||
self.power = power
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
|
||||
self.decay = 0.0
|
||||
self.optimization_step = 0
|
||||
|
||||
def get_decay(self, optimization_step):
|
||||
"""
|
||||
Compute the decay factor for the exponential moving average.
|
||||
"""
|
||||
step = max(0, optimization_step - self.update_after_step - 1)
|
||||
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
||||
|
||||
if step <= 0:
|
||||
return 0.0
|
||||
|
||||
return max(self.min_value, min(value, self.max_value))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, new_model):
|
||||
self.decay = self.get_decay(self.optimization_step)
|
||||
|
||||
# old_all_dataptrs = set()
|
||||
# for param in new_model.parameters():
|
||||
# data_ptr = param.data_ptr()
|
||||
# if data_ptr != 0:
|
||||
# old_all_dataptrs.add(data_ptr)
|
||||
|
||||
# all_dataptrs = set()
|
||||
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False):
|
||||
for param, ema_param in zip(
|
||||
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False
|
||||
):
|
||||
# iterative over immediate parameters only.
|
||||
if isinstance(param, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
|
||||
# data_ptr = param.data_ptr()
|
||||
# if data_ptr != 0:
|
||||
# all_dataptrs.add(data_ptr)
|
||||
|
||||
if isinstance(module, _BatchNorm):
|
||||
# skip batchnorms
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
elif not param.requires_grad:
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
else:
|
||||
ema_param.mul_(self.decay)
|
||||
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
|
||||
|
||||
# verify that iterating over module and then parameters is identical to parameters recursively.
|
||||
# assert old_all_dataptrs == all_dataptrs
|
||||
self.optimization_step += 1
|
||||
46
lerobot/common/policies/diffusion/model/lr_scheduler.py
Normal file
46
lerobot/common/policies/diffusion/model/lr_scheduler.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
name: Union[str, SchedulerType],
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
num_training_steps: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Added kwargs vs diffuser's original implementation
|
||||
|
||||
Unified API to get any scheduler from its name.
|
||||
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer, **kwargs)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs
|
||||
)
|
||||
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
65
lerobot/common/policies/diffusion/model/mask_generator.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
|
||||
|
||||
class LowdimMaskGenerator(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
action_dim,
|
||||
obs_dim,
|
||||
# obs mask setup
|
||||
max_n_obs_steps=2,
|
||||
fix_obs_steps=True,
|
||||
# action mask
|
||||
action_visible=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.action_dim = action_dim
|
||||
self.obs_dim = obs_dim
|
||||
self.max_n_obs_steps = max_n_obs_steps
|
||||
self.fix_obs_steps = fix_obs_steps
|
||||
self.action_visible = action_visible
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, shape, seed=None):
|
||||
device = self.device
|
||||
B, T, D = shape # noqa: N806
|
||||
assert (self.action_dim + self.obs_dim) == D
|
||||
|
||||
# create all tensors on this device
|
||||
rng = torch.Generator(device=device)
|
||||
if seed is not None:
|
||||
rng = rng.manual_seed(seed)
|
||||
|
||||
# generate dim mask
|
||||
dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device)
|
||||
is_action_dim = dim_mask.clone()
|
||||
is_action_dim[..., : self.action_dim] = True
|
||||
is_obs_dim = ~is_action_dim
|
||||
|
||||
# generate obs mask
|
||||
if self.fix_obs_steps:
|
||||
obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device)
|
||||
else:
|
||||
obs_steps = torch.randint(
|
||||
low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device
|
||||
)
|
||||
|
||||
steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T)
|
||||
obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
obs_mask = obs_mask & is_obs_dim
|
||||
|
||||
# generate action mask
|
||||
if self.action_visible:
|
||||
action_steps = torch.maximum(
|
||||
obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device)
|
||||
)
|
||||
action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D)
|
||||
action_mask = action_mask & is_action_dim
|
||||
|
||||
mask = obs_mask
|
||||
if self.action_visible:
|
||||
mask = mask | action_mask
|
||||
|
||||
return mask
|
||||
15
lerobot/common/policies/diffusion/model/module_attr_mixin.py
Normal file
15
lerobot/common/policies/diffusion/model/module_attr_mixin.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._dummy_variable = nn.Parameter()
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||
@@ -5,9 +5,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
from diffusion_policy.common.pytorch_util import replace_submodules
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
358
lerobot/common/policies/diffusion/model/normalizer.py
Normal file
358
lerobot/common/policies/diffusion/model/normalizer.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import zarr
|
||||
|
||||
from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class LinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self.params_dict[key] = _fit(
|
||||
value,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
else:
|
||||
self.params_dict["_default"] = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
return SingleFieldLinearNormalizer(self.params_dict[key])
|
||||
|
||||
def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
|
||||
self.params_dict[key] = value.params_dict
|
||||
|
||||
def _normalize_impl(self, x, forward=True):
|
||||
if isinstance(x, dict):
|
||||
result = {}
|
||||
for key, value in x.items():
|
||||
params = self.params_dict[key]
|
||||
result[key] = _normalize(value, params, forward=forward)
|
||||
return result
|
||||
else:
|
||||
if "_default" not in self.params_dict:
|
||||
raise RuntimeError("Not initialized")
|
||||
params = self.params_dict["_default"]
|
||||
return _normalize(x, params, forward=forward)
|
||||
|
||||
def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self._normalize_impl(x, forward=False)
|
||||
|
||||
def get_input_stats(self) -> Dict:
|
||||
if len(self.params_dict) == 0:
|
||||
raise RuntimeError("Not initialized")
|
||||
if len(self.params_dict) == 1 and "_default" in self.params_dict:
|
||||
return self.params_dict["_default"]["input_stats"]
|
||||
|
||||
result = {}
|
||||
for key, value in self.params_dict.items():
|
||||
if key != "_default":
|
||||
result[key] = value["input_stats"]
|
||||
return result
|
||||
|
||||
def get_output_stats(self, key="_default"):
|
||||
input_stats = self.get_input_stats()
|
||||
if "min" in input_stats:
|
||||
# no dict
|
||||
return dict_apply(input_stats, self.normalize)
|
||||
|
||||
result = {}
|
||||
for key, group in input_stats.items():
|
||||
this_dict = {}
|
||||
for name, value in group.items():
|
||||
this_dict[name] = self.normalize({key: value})[key]
|
||||
result[key] = this_dict
|
||||
return result
|
||||
|
||||
|
||||
class SingleFieldLinearNormalizer(DictOfTensorMixin):
|
||||
avaliable_modes = ["limits", "gaussian"]
|
||||
|
||||
@torch.no_grad()
|
||||
def fit(
|
||||
self,
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
self.params_dict = _fit(
|
||||
data,
|
||||
last_n_dims=last_n_dims,
|
||||
dtype=dtype,
|
||||
mode=mode,
|
||||
output_max=output_max,
|
||||
output_min=output_min,
|
||||
range_eps=range_eps,
|
||||
fit_offset=fit_offset,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
|
||||
obj = cls()
|
||||
obj.fit(data, **kwargs)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def create_manual(
|
||||
cls,
|
||||
scale: Union[torch.Tensor, np.ndarray],
|
||||
offset: Union[torch.Tensor, np.ndarray],
|
||||
input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
|
||||
):
|
||||
def to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.from_numpy(x)
|
||||
x = x.flatten()
|
||||
return x
|
||||
|
||||
# check
|
||||
for x in [offset] + list(input_stats_dict.values()):
|
||||
assert x.shape == scale.shape
|
||||
assert x.dtype == scale.dtype
|
||||
|
||||
params_dict = nn.ParameterDict(
|
||||
{
|
||||
"scale": to_tensor(scale),
|
||||
"offset": to_tensor(offset),
|
||||
"input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
|
||||
}
|
||||
)
|
||||
return cls(params_dict)
|
||||
|
||||
@classmethod
|
||||
def create_identity(cls, dtype=torch.float32):
|
||||
scale = torch.tensor([1], dtype=dtype)
|
||||
offset = torch.tensor([0], dtype=dtype)
|
||||
input_stats_dict = {
|
||||
"min": torch.tensor([-1], dtype=dtype),
|
||||
"max": torch.tensor([1], dtype=dtype),
|
||||
"mean": torch.tensor([0], dtype=dtype),
|
||||
"std": torch.tensor([1], dtype=dtype),
|
||||
}
|
||||
return cls.create_manual(scale, offset, input_stats_dict)
|
||||
|
||||
def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=True)
|
||||
|
||||
def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return _normalize(x, self.params_dict, forward=False)
|
||||
|
||||
def get_input_stats(self):
|
||||
return self.params_dict["input_stats"]
|
||||
|
||||
def get_output_stats(self):
|
||||
return dict_apply(self.params_dict["input_stats"], self.normalize)
|
||||
|
||||
def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
|
||||
return self.normalize(x)
|
||||
|
||||
|
||||
def _fit(
|
||||
data: Union[torch.Tensor, np.ndarray, zarr.Array],
|
||||
last_n_dims=1,
|
||||
dtype=torch.float32,
|
||||
mode="limits",
|
||||
output_max=1.0,
|
||||
output_min=-1.0,
|
||||
range_eps=1e-4,
|
||||
fit_offset=True,
|
||||
):
|
||||
assert mode in ["limits", "gaussian"]
|
||||
assert last_n_dims >= 0
|
||||
assert output_max > output_min
|
||||
|
||||
# convert data to torch and type
|
||||
if isinstance(data, zarr.Array):
|
||||
data = data[:]
|
||||
if isinstance(data, np.ndarray):
|
||||
data = torch.from_numpy(data)
|
||||
if dtype is not None:
|
||||
data = data.type(dtype)
|
||||
|
||||
# convert shape
|
||||
dim = 1
|
||||
if last_n_dims > 0:
|
||||
dim = np.prod(data.shape[-last_n_dims:])
|
||||
data = data.reshape(-1, dim)
|
||||
|
||||
# compute input stats min max mean std
|
||||
input_min, _ = data.min(axis=0)
|
||||
input_max, _ = data.max(axis=0)
|
||||
input_mean = data.mean(axis=0)
|
||||
input_std = data.std(axis=0)
|
||||
|
||||
# compute scale and offset
|
||||
if mode == "limits":
|
||||
if fit_offset:
|
||||
# unit scale
|
||||
input_range = input_max - input_min
|
||||
ignore_dim = input_range < range_eps
|
||||
input_range[ignore_dim] = output_max - output_min
|
||||
scale = (output_max - output_min) / input_range
|
||||
offset = output_min - scale * input_min
|
||||
offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
|
||||
# ignore dims scaled to mean of output max and min
|
||||
else:
|
||||
# use this when data is pre-zero-centered.
|
||||
assert output_max > 0
|
||||
assert output_min < 0
|
||||
# unit abs
|
||||
output_abs = min(abs(output_min), abs(output_max))
|
||||
input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
|
||||
ignore_dim = input_abs < range_eps
|
||||
input_abs[ignore_dim] = output_abs
|
||||
# don't scale constant channels
|
||||
scale = output_abs / input_abs
|
||||
offset = torch.zeros_like(input_mean)
|
||||
elif mode == "gaussian":
|
||||
ignore_dim = input_std < range_eps
|
||||
scale = input_std.clone()
|
||||
scale[ignore_dim] = 1
|
||||
scale = 1 / scale
|
||||
|
||||
offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)
|
||||
|
||||
# save
|
||||
this_params = nn.ParameterDict(
|
||||
{
|
||||
"scale": scale,
|
||||
"offset": offset,
|
||||
"input_stats": nn.ParameterDict(
|
||||
{"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
|
||||
),
|
||||
}
|
||||
)
|
||||
for p in this_params.parameters():
|
||||
p.requires_grad_(False)
|
||||
return this_params
|
||||
|
||||
|
||||
def _normalize(x, params, forward=True):
|
||||
assert "scale" in params
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x)
|
||||
scale = params["scale"]
|
||||
offset = params["offset"]
|
||||
x = x.to(device=scale.device, dtype=scale.dtype)
|
||||
src_shape = x.shape
|
||||
x = x.reshape(-1, scale.shape[0])
|
||||
x = x * scale + offset if forward else (x - offset) / scale
|
||||
x = x.reshape(src_shape)
|
||||
return x
|
||||
|
||||
|
||||
def test():
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0, atol=1e-3)
|
||||
assert np.allclose(datan.min(), 0.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
normalizer = SingleFieldLinearNormalizer()
|
||||
normalizer.fit(data, mode="gaussian", last_n_dims=0)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.mean(), 0.0, atol=1e-3)
|
||||
assert np.allclose(datan.std(), 1.0, atol=1e-3)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
# dict
|
||||
data = torch.zeros((100, 10, 9, 2)).uniform_()
|
||||
data[..., 0, 0] = 0
|
||||
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data, mode="limits", last_n_dims=2)
|
||||
datan = normalizer.normalize(data)
|
||||
assert datan.shape == data.shape
|
||||
assert np.allclose(datan.max(), 1.0)
|
||||
assert np.allclose(datan.min(), -1.0)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
assert torch.allclose(data, dataun, atol=1e-7)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
data = {
|
||||
"obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
|
||||
"action": torch.zeros((1000, 128, 2)).uniform_() * 512,
|
||||
}
|
||||
normalizer = LinearNormalizer()
|
||||
normalizer.fit(data)
|
||||
datan = normalizer.normalize(data)
|
||||
dataun = normalizer.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
||||
|
||||
_ = normalizer.get_input_stats()
|
||||
_ = normalizer.get_output_stats()
|
||||
|
||||
state_dict = normalizer.state_dict()
|
||||
n = LinearNormalizer()
|
||||
n.load_state_dict(state_dict)
|
||||
datan = n.normalize(data)
|
||||
dataun = n.unnormalize(datan)
|
||||
for key in data:
|
||||
assert torch.allclose(data[key], dataun[key], atol=1e-4)
|
||||
@@ -0,0 +1,19 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
971
lerobot/common/policies/diffusion/model/tensor_utils.py
Normal file
971
lerobot/common/policies/diffusion/model/tensor_utils.py
Normal file
@@ -0,0 +1,971 @@
|
||||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
|
||||
"""
|
||||
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
|
||||
{data_type: function_to_apply}.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
type_func_dict (dict): a mapping from data types to the functions to be
|
||||
applied for each data type.
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
assert list not in type_func_dict
|
||||
assert tuple not in type_func_dict
|
||||
assert dict not in type_func_dict
|
||||
|
||||
if isinstance(x, (dict, collections.OrderedDict)):
|
||||
new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {}
|
||||
for k, v in x.items():
|
||||
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
|
||||
return new_x
|
||||
elif isinstance(x, (list, tuple)):
|
||||
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
|
||||
if isinstance(x, tuple):
|
||||
ret = tuple(ret)
|
||||
return ret
|
||||
else:
|
||||
for t, f in type_func_dict.items():
|
||||
if isinstance(x, t):
|
||||
return f(x)
|
||||
else:
|
||||
raise NotImplementedError("Cannot handle data type %s" % str(type(x)))
|
||||
|
||||
|
||||
def map_tensor(x, func):
|
||||
"""
|
||||
Apply function @func to torch.Tensor objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each tensor
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_ndarray(x, func):
|
||||
"""
|
||||
Apply function @func to np.ndarray objects in a nested dictionary or
|
||||
list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
np.ndarray: func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
|
||||
"""
|
||||
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
|
||||
np.ndarray objects in a nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
tensor_func (function): function to apply to each tensor
|
||||
ndarray_Func (function): function to apply to each array
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: tensor_func,
|
||||
np.ndarray: ndarray_func,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def clone(x):
|
||||
"""
|
||||
Clones all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.clone(),
|
||||
np.ndarray: lambda x: x.copy(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def detach(x):
|
||||
"""
|
||||
Detaches all torch tensors in nested dictionary or list
|
||||
or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.detach(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_batch(x):
|
||||
"""
|
||||
Introduces a leading batch dimension of 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[None, ...],
|
||||
np.ndarray: lambda x: x[None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_sequence(x):
|
||||
"""
|
||||
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
|
||||
arrays in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, None, ...],
|
||||
np.ndarray: lambda x: x[:, None, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def index_at_time(x, ind):
|
||||
"""
|
||||
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
|
||||
nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
ind (int): index
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x[:, ind, ...],
|
||||
np.ndarray: lambda x: x[:, ind, ...],
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
|
||||
"""
|
||||
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
|
||||
in nested dictionary or list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
dim (int): dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
|
||||
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def contiguous(x):
|
||||
"""
|
||||
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
|
||||
list or tuple and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.contiguous(),
|
||||
np.ndarray: lambda x: np.ascontiguousarray(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_device(x, device):
|
||||
"""
|
||||
Sends all torch tensors in nested dictionary or list or tuple to device
|
||||
@device, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, d=device: x.to(d),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_tensor(x):
|
||||
"""
|
||||
Converts all numpy arrays in nested dictionary or list or tuple to
|
||||
torch tensors (and leaves existing torch Tensors as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x,
|
||||
np.ndarray: lambda x: torch.from_numpy(x),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_numpy(x):
|
||||
"""
|
||||
Converts all torch tensors in nested dictionary or list or tuple to
|
||||
numpy (and leaves existing numpy arrays as-is), and returns
|
||||
a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy()
|
||||
else:
|
||||
return tensor.detach().numpy()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_list(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to a list, and returns a new nested structure. Useful for
|
||||
json encoding.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
|
||||
def f(tensor):
|
||||
if tensor.is_cuda:
|
||||
return tensor.detach().cpu().numpy().tolist()
|
||||
else:
|
||||
return tensor.detach().numpy().tolist()
|
||||
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: f,
|
||||
np.ndarray: lambda x: x.tolist(),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_float(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to float type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.float(),
|
||||
np.ndarray: lambda x: x.astype(np.float32),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_uint8(x):
|
||||
"""
|
||||
Converts all torch tensors and numpy arrays in nested dictionary or list
|
||||
or tuple to uint8 type entries, and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.byte(),
|
||||
np.ndarray: lambda x: x.astype(np.uint8),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def to_torch(x, device):
|
||||
"""
|
||||
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
|
||||
torch tensors on device @device and returns a new nested structure.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
device (torch.Device): device to send tensors to
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return to_device(to_float(to_tensor(x)), device)
|
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
|
||||
"""
|
||||
Convert tensor to one-hot representation, assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): tensor containing integer labels
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): tensor containing one-hot representation of labels
|
||||
"""
|
||||
x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device)
|
||||
x.scatter_(-1, tensor.unsqueeze(-1), 1)
|
||||
return x
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
|
||||
"""
|
||||
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
|
||||
assuming a certain number of total class labels.
|
||||
|
||||
Args:
|
||||
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
num_class (int): number of classes
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc))
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
|
||||
"""
|
||||
Flatten a tensor in all dimensions from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to flatten
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): flattened tensor
|
||||
"""
|
||||
fixed_size = x.size()[:begin_axis]
|
||||
_s = list(fixed_size) + [-1]
|
||||
return x.reshape(*_s)
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
|
||||
"""
|
||||
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): which axis to flatten from
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions in a tensor to a target dimension.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to reshape
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reshaped tensor
|
||||
"""
|
||||
assert begin_axis <= end_axis
|
||||
assert begin_axis >= 0
|
||||
assert end_axis < len(x.shape)
|
||||
assert isinstance(target_dims, (tuple, list))
|
||||
s = x.shape
|
||||
final_s = []
|
||||
for i in range(len(s)):
|
||||
if i == begin_axis:
|
||||
final_s.extend(target_dims)
|
||||
elif i < begin_axis or i > end_axis:
|
||||
final_s.append(s[i])
|
||||
return x.reshape(*final_s)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
|
||||
"""
|
||||
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
|
||||
to a target dimension.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
target_dims (tuple or list): target shape for the range of dimensions
|
||||
(@begin_axis, @end_axis)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=t
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
|
||||
"""
|
||||
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
|
||||
all tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
begin_axis (int): begin dimension
|
||||
end_axis (int): end dimension
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
|
||||
x, begin_axis=b, end_axis=e, target_dims=[-1]
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
|
||||
"""
|
||||
Expand a tensor at a single dimension @dim by @size
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input tensor
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): expanded tensor
|
||||
"""
|
||||
assert dim < x.ndimension()
|
||||
assert x.shape[dim] == 1
|
||||
expand_dims = [-1] * x.ndimension()
|
||||
expand_dims[dim] = size
|
||||
return x.expand(*expand_dims)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
|
||||
"""
|
||||
Expand all tensors in nested dictionary or list or tuple at a single
|
||||
dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
|
||||
"""
|
||||
Unsqueeze and expand a tensor at a dimension @dim by @size.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size to expand
|
||||
dim (int): dimension to unsqueeze and expand
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze(x, dim)
|
||||
return expand_at(x, size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
|
||||
"""
|
||||
Repeat a dimension by combining expand and reshape operations.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
repeats (int): number of times to repeat the target dimension
|
||||
dim (int): dimension to repeat on
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
x = unsqueeze_expand_at(x, repeats, dim + 1)
|
||||
return join_dimensions(x, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
|
||||
"""
|
||||
Reduce tensor at a dimension by named reduction functions.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to be reduced
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): reduced tensor
|
||||
"""
|
||||
assert x.ndimension() > dim
|
||||
assert reduction in ["sum", "max", "mean", "flatten"]
|
||||
if reduction == "flatten":
|
||||
x = flatten(x, begin_axis=dim)
|
||||
elif reduction == "max":
|
||||
x = torch.max(x, dim=dim)[0] # [B, D]
|
||||
elif reduction == "sum":
|
||||
x = torch.sum(x, dim=dim)
|
||||
else:
|
||||
x = torch.mean(x, dim=dim)
|
||||
return x
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
|
||||
"""
|
||||
Reduces all tensors in nested dictionary or list or tuple at a dimension
|
||||
using a named reduction function.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
reduction (str): one of ["sum", "max", "mean", "flatten"]
|
||||
dim (int): dimension to be reduced (or begin axis for flatten)
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
This function indexes out a target dimension of a tensor in a structured way,
|
||||
by allowing a different value to be selected for each member of a flat index
|
||||
tensor (@indices) corresponding to a source dimension. This can be interpreted
|
||||
as moving along the source dimension, using the corresponding index value
|
||||
in @indices to select values for all other dimensions outside of the
|
||||
source and target dimensions. A common use case is to gather values
|
||||
in target dimension 1 for each batch member (target dimension 0).
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): tensor to gather values for
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
|
||||
"""
|
||||
assert len(indices.shape) == 1
|
||||
assert x.shape[source_dim] == indices.shape[0]
|
||||
|
||||
# unsqueeze in all dimensions except the source dimension
|
||||
new_shape = [1] * x.ndimension()
|
||||
new_shape[source_dim] = -1
|
||||
indices = indices.reshape(*new_shape)
|
||||
|
||||
# repeat in all dimensions - but preserve shape of source dimension,
|
||||
# and make sure target_dimension has singleton dimension
|
||||
expand_shape = list(x.shape)
|
||||
expand_shape[source_dim] = -1
|
||||
expand_shape[target_dim] = 1
|
||||
indices = indices.expand(*expand_shape)
|
||||
|
||||
out = x.gather(dim=target_dim, index=indices)
|
||||
return out.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
|
||||
"""
|
||||
Apply @gather_along_dim_with_dim_single to all tensors in a nested
|
||||
dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
target_dim (int): dimension to gather values along
|
||||
source_dim (int): dimension to hold constant and use for gathering values
|
||||
from the other dimensions
|
||||
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
|
||||
@source_dim
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple
|
||||
"""
|
||||
return map_tensor(
|
||||
x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i)
|
||||
)
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
|
||||
"""
|
||||
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
|
||||
the batch given an index for each sequence.
|
||||
|
||||
Args:
|
||||
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Return:
|
||||
y (torch.Tensor): indexed tensor of shape [B, ....]
|
||||
"""
|
||||
return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
|
||||
"""
|
||||
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
|
||||
for tensors with leading dimensions [B, T, ...].
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
indices (torch.Tensor): tensor indices of shape [B]
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
|
||||
"""
|
||||
return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
|
||||
"""
|
||||
Pad input tensor or array @seq in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (np.ndarray or torch.Tensor): sequence to be padded
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (np.ndarray or torch.Tensor)
|
||||
"""
|
||||
assert isinstance(seq, (np.ndarray, torch.Tensor))
|
||||
assert pad_same or pad_values is not None
|
||||
if pad_values is not None:
|
||||
assert isinstance(pad_values, float)
|
||||
repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave
|
||||
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
|
||||
ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like
|
||||
seq_dim = 1 if batched else 0
|
||||
|
||||
begin_pad = []
|
||||
end_pad = []
|
||||
|
||||
if padding[0] > 0:
|
||||
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
|
||||
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
|
||||
if padding[1] > 0:
|
||||
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
|
||||
end_pad.append(repeat_func(pad, padding[1], seq_dim))
|
||||
|
||||
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
|
||||
|
||||
|
||||
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
|
||||
"""
|
||||
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
|
||||
|
||||
Args:
|
||||
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
|
||||
batched (bool): if sequence has the batch dimension
|
||||
pad_same (bool): if pad by duplicating
|
||||
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
|
||||
|
||||
Returns:
|
||||
padded sequence (dict or list or tuple)
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
seq,
|
||||
{
|
||||
torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
|
||||
x, p, b, ps, pv
|
||||
),
|
||||
np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single(
|
||||
x, p, b, ps, pv
|
||||
),
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def assert_size_at_dim_single(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that array or tensor @x has size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (np.ndarray or torch.Tensor): input array or tensor
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
msg (str): text to display if assertion fails
|
||||
"""
|
||||
assert x.shape[dim] == size, msg
|
||||
|
||||
|
||||
def assert_size_at_dim(x, size, dim, msg):
|
||||
"""
|
||||
Ensure that arrays and tensors in nested dictionary or list or tuple have
|
||||
size @size in dim @dim.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
size (int): size that tensors should have at @dim
|
||||
dim (int): dimension to check
|
||||
"""
|
||||
map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
|
||||
|
||||
|
||||
def get_shape(x):
|
||||
"""
|
||||
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
|
||||
|
||||
Args:
|
||||
x (dict or list or tuple): a possibly nested dictionary or list or tuple
|
||||
|
||||
Returns:
|
||||
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
|
||||
tensor's shape
|
||||
"""
|
||||
return recursive_dict_list_tuple_apply(
|
||||
x,
|
||||
{
|
||||
torch.Tensor: lambda x: x.shape,
|
||||
np.ndarray: lambda x: x.shape,
|
||||
type(None): lambda x: x,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def list_of_flat_dict_to_dict_of_list(list_of_dict):
|
||||
"""
|
||||
Helper function to go from a list of flat dictionaries to a dictionary of lists.
|
||||
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
|
||||
floats, etc.
|
||||
|
||||
Args:
|
||||
list_of_dict (list): list of flat dictionaries
|
||||
|
||||
Returns:
|
||||
dict_of_list (dict): dictionary of lists
|
||||
"""
|
||||
assert isinstance(list_of_dict, list)
|
||||
dic = collections.OrderedDict()
|
||||
for i in range(len(list_of_dict)):
|
||||
for k in list_of_dict[i]:
|
||||
if k not in dic:
|
||||
dic[k] = []
|
||||
dic[k].append(list_of_dict[i][k])
|
||||
return dic
|
||||
|
||||
|
||||
def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""):
|
||||
"""
|
||||
Flatten a nested dict or list to a list.
|
||||
|
||||
For example, given a dict
|
||||
{
|
||||
a: 1
|
||||
b: {
|
||||
c: 2
|
||||
}
|
||||
c: 3
|
||||
}
|
||||
|
||||
the function would return [(a, 1), (b_c, 2), (c, 3)]
|
||||
|
||||
Args:
|
||||
d (dict, list): a nested dict or list to be flattened
|
||||
parent_key (str): recursion helper
|
||||
sep (str): separator for nesting keys
|
||||
item_key (str): recursion helper
|
||||
Returns:
|
||||
list: a list of (key, value) tuples
|
||||
"""
|
||||
items = []
|
||||
if isinstance(d, (tuple, list)):
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
for i, v in enumerate(d):
|
||||
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
|
||||
return items
|
||||
elif isinstance(d, dict):
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
for k, v in d.items():
|
||||
assert isinstance(k, str)
|
||||
items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
|
||||
return items
|
||||
else:
|
||||
new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
|
||||
return [(new_key, d)]
|
||||
|
||||
|
||||
def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
|
||||
"""
|
||||
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
|
||||
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
|
||||
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
|
||||
outputs to [B, T, ...].
|
||||
|
||||
Args:
|
||||
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
|
||||
of leading dimensions [B, T, ...]
|
||||
op: a layer op that accepts inputs
|
||||
activation: activation to apply at the output
|
||||
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
|
||||
inputs_as_args (bool) whether to feed input as a args list to the op
|
||||
kwargs (dict): other kwargs to supply to the op
|
||||
|
||||
Returns:
|
||||
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
|
||||
"""
|
||||
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
|
||||
inputs = join_dimensions(inputs, 0, 1)
|
||||
if inputs_as_kwargs:
|
||||
outputs = op(**inputs, **kwargs)
|
||||
elif inputs_as_args:
|
||||
outputs = op(*inputs, **kwargs)
|
||||
else:
|
||||
outputs = op(inputs, **kwargs)
|
||||
|
||||
if activation is not None:
|
||||
outputs = map_tensor(outputs, activation)
|
||||
outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
|
||||
return outputs
|
||||
@@ -5,16 +5,16 @@ import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
|
||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from .multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
cfg_device,
|
||||
cfg_noise_scheduler,
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
@@ -62,8 +62,9 @@ class DiffusionPolicy(nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.device = torch.device("cuda")
|
||||
self.diffusion.cuda()
|
||||
self.device = torch.device(cfg_device)
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
|
||||
76
lerobot/common/policies/diffusion/pytorch_utils.py
Normal file
76
lerobot/common/policies/diffusion/pytorch_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
|
||||
def get_resnet(name, weights=None, **kwargs):
|
||||
"""
|
||||
name: resnet18, resnet34, resnet50
|
||||
weights: "IMAGENET1K_V1", "r3m"
|
||||
"""
|
||||
# load r3m weights
|
||||
if (weights == "r3m") or (weights == "R3M"):
|
||||
return get_r3m(name=name, **kwargs)
|
||||
|
||||
func = getattr(torchvision.models, name)
|
||||
resnet = func(weights=weights, **kwargs)
|
||||
resnet.fc = torch.nn.Identity()
|
||||
return resnet
|
||||
|
||||
|
||||
def get_r3m(name, **kwargs):
|
||||
"""
|
||||
name: resnet18, resnet34, resnet50
|
||||
"""
|
||||
import r3m
|
||||
|
||||
r3m.device = "cpu"
|
||||
model = r3m.load_r3m(name)
|
||||
r3m_model = model.module
|
||||
resnet_model = r3m_model.convnet
|
||||
resnet_model = resnet_model.to("cpu")
|
||||
return resnet_model
|
||||
|
||||
|
||||
def dict_apply(
|
||||
x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
result = {}
|
||||
for key, value in x.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = dict_apply(value, func)
|
||||
else:
|
||||
result[key] = func(value)
|
||||
return result
|
||||
|
||||
|
||||
def replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
predicate: Return true if the module is to be replaced.
|
||||
func: Return new module to use.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parent, k in bn_list:
|
||||
parent_module = root_module
|
||||
if len(parent) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parent))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
assert len(bn_list) == 0
|
||||
return root_module
|
||||
614
lerobot/common/policies/diffusion/replay_buffer.py
Normal file
614
lerobot/common/policies/diffusion/replay_buffer.py
Normal file
@@ -0,0 +1,614 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
import numcodecs
|
||||
import numpy as np
|
||||
import zarr
|
||||
|
||||
|
||||
def check_chunks_compatible(chunks: tuple, shape: tuple):
|
||||
assert len(shape) == len(chunks)
|
||||
for c in chunks:
|
||||
assert isinstance(c, numbers.Integral)
|
||||
assert c > 0
|
||||
|
||||
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
|
||||
old_arr = group[name]
|
||||
if chunks is None:
|
||||
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
|
||||
check_chunks_compatible(chunks, old_arr.shape)
|
||||
|
||||
if compressor is None:
|
||||
compressor = old_arr.compressor
|
||||
|
||||
if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
|
||||
# no change
|
||||
return old_arr
|
||||
|
||||
# rechunk recompress
|
||||
group.move(name, tmp_key)
|
||||
old_arr = group[tmp_key]
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=old_arr,
|
||||
dest=group,
|
||||
name=name,
|
||||
chunks=chunks,
|
||||
compressor=compressor,
|
||||
)
|
||||
del group[tmp_key]
|
||||
arr = group[name]
|
||||
return arr
|
||||
|
||||
|
||||
def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
|
||||
"""
|
||||
Common shapes
|
||||
T,D
|
||||
T,N,D
|
||||
T,H,W,C
|
||||
T,N,H,W,C
|
||||
"""
|
||||
itemsize = np.dtype(dtype).itemsize
|
||||
# reversed
|
||||
rshape = list(shape[::-1])
|
||||
if max_chunk_length is not None:
|
||||
rshape[-1] = int(max_chunk_length)
|
||||
split_idx = len(shape) - 1
|
||||
for i in range(len(shape) - 1):
|
||||
this_chunk_bytes = itemsize * np.prod(rshape[:i])
|
||||
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
|
||||
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
|
||||
split_idx = i
|
||||
|
||||
rchunks = rshape[:split_idx]
|
||||
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
|
||||
this_max_chunk_length = rshape[split_idx]
|
||||
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
|
||||
rchunks.append(next_chunk_length)
|
||||
len_diff = len(shape) - len(rchunks)
|
||||
rchunks.extend([1] * len_diff)
|
||||
chunks = tuple(rchunks[::-1])
|
||||
# print(np.prod(chunks) * itemsize / target_chunk_bytes)
|
||||
return chunks
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
"""
|
||||
Zarr-based temporal datastructure.
|
||||
Assumes first dimension to be time. Only chunk in time dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, root: zarr.Group | dict[str, dict]):
|
||||
"""
|
||||
Dummy constructor. Use copy_from* and create_from* class methods instead.
|
||||
"""
|
||||
assert "data" in root
|
||||
assert "meta" in root
|
||||
assert "episode_ends" in root["meta"]
|
||||
for value in root["data"].values():
|
||||
assert value.shape[0] == root["meta"]["episode_ends"][-1]
|
||||
self.root = root
|
||||
|
||||
# ============= create constructors ===============
|
||||
@classmethod
|
||||
def create_empty_zarr(cls, storage=None, root=None):
|
||||
if root is None:
|
||||
if storage is None:
|
||||
storage = zarr.MemoryStore()
|
||||
root = zarr.group(store=storage)
|
||||
root.require_group("data", overwrite=False)
|
||||
meta = root.require_group("meta", overwrite=False)
|
||||
if "episode_ends" not in meta:
|
||||
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
def create_empty_numpy(cls):
|
||||
root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
def create_from_group(cls, group, **kwargs):
|
||||
if "data" not in group:
|
||||
# create from stratch
|
||||
buffer = cls.create_empty_zarr(root=group, **kwargs)
|
||||
else:
|
||||
# already exist
|
||||
buffer = cls(root=group, **kwargs)
|
||||
return buffer
|
||||
|
||||
@classmethod
|
||||
def create_from_path(cls, zarr_path, mode="r", **kwargs):
|
||||
"""
|
||||
Open a on-disk zarr directly (for dataset larger than memory).
|
||||
Slower.
|
||||
"""
|
||||
group = zarr.open(os.path.expanduser(zarr_path), mode)
|
||||
return cls.create_from_group(group, **kwargs)
|
||||
|
||||
# ============= copy constructors ===============
|
||||
@classmethod
|
||||
def copy_from_store(
|
||||
cls,
|
||||
src_store,
|
||||
store=None,
|
||||
keys=None,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: dict | str | numcodecs.abc.Codec | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load to memory.
|
||||
"""
|
||||
src_root = zarr.group(src_store)
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
root = None
|
||||
if store is None:
|
||||
# numpy backend
|
||||
meta = {}
|
||||
for key, value in src_root["meta"].items():
|
||||
if len(value.shape) == 0:
|
||||
meta[key] = np.array(value)
|
||||
else:
|
||||
meta[key] = value[:]
|
||||
|
||||
if keys is None:
|
||||
keys = src_root["data"].keys()
|
||||
data = {}
|
||||
for key in keys:
|
||||
arr = src_root["data"][key]
|
||||
data[key] = arr[:]
|
||||
|
||||
root = {"meta": meta, "data": data}
|
||||
else:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
if keys is None:
|
||||
keys = src_root["data"].keys()
|
||||
for key in keys:
|
||||
value = src_root["data"][key]
|
||||
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
dest_path=this_path,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
chunks=cks,
|
||||
compressor=cpr,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
buffer = cls(root=root)
|
||||
return buffer
|
||||
|
||||
@classmethod
|
||||
def copy_from_path(
|
||||
cls,
|
||||
zarr_path,
|
||||
backend=None,
|
||||
store=None,
|
||||
keys=None,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: dict | str | numcodecs.abc.Codec | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Copy a on-disk zarr to in-memory compressed.
|
||||
Recommended
|
||||
"""
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
if backend == "numpy":
|
||||
print("backend argument is deprecated!")
|
||||
store = None
|
||||
group = zarr.open(os.path.expanduser(zarr_path), "r")
|
||||
return cls.copy_from_store(
|
||||
src_store=group.store,
|
||||
store=store,
|
||||
keys=keys,
|
||||
chunks=chunks,
|
||||
compressors=compressors,
|
||||
if_exists=if_exists,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ============= save methods ===============
|
||||
def save_to_store(
|
||||
self,
|
||||
store,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
root = zarr.group(store)
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
if self.backend == "zarr":
|
||||
# recompression free copy
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
dest_path="/meta",
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
meta_group = root.create_group("meta", overwrite=True)
|
||||
# save meta, no chunking
|
||||
for key, value in self.root["meta"].items():
|
||||
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
|
||||
|
||||
# save data, chunk
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
for key, value in self.root["data"].items():
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if isinstance(value, zarr.Array):
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
dest_path=this_path,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
chunks=cks,
|
||||
compressor=cpr,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# numpy
|
||||
_ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
|
||||
return store
|
||||
|
||||
def save_to_path(
|
||||
self,
|
||||
zarr_path,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
|
||||
return self.save_to_store(
|
||||
store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_compressor(compressor="default"):
|
||||
if compressor == "default":
|
||||
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
|
||||
elif compressor == "disk":
|
||||
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
|
||||
return compressor
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
|
||||
# allows compressor to be explicitly set to None
|
||||
cpr = "nil"
|
||||
if isinstance(compressors, dict):
|
||||
if key in compressors:
|
||||
cpr = cls.resolve_compressor(compressors[key])
|
||||
elif isinstance(array, zarr.Array):
|
||||
cpr = array.compressor
|
||||
else:
|
||||
cpr = cls.resolve_compressor(compressors)
|
||||
# backup default
|
||||
if cpr == "nil":
|
||||
cpr = cls.resolve_compressor("default")
|
||||
return cpr
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
|
||||
cks = None
|
||||
if isinstance(chunks, dict):
|
||||
if key in chunks:
|
||||
cks = chunks[key]
|
||||
elif isinstance(array, zarr.Array):
|
||||
cks = array.chunks
|
||||
elif isinstance(chunks, tuple):
|
||||
cks = chunks
|
||||
else:
|
||||
raise TypeError(f"Unsupported chunks type {type(chunks)}")
|
||||
# backup default
|
||||
if cks is None:
|
||||
cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
|
||||
# check
|
||||
check_chunks_compatible(chunks=cks, shape=array.shape)
|
||||
return cks
|
||||
|
||||
# ============= properties =================
|
||||
@cached_property
|
||||
def data(self):
|
||||
return self.root["data"]
|
||||
|
||||
@cached_property
|
||||
def meta(self):
|
||||
return self.root["meta"]
|
||||
|
||||
def update_meta(self, data):
|
||||
# sanitize data
|
||||
np_data = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
np_data[key] = value
|
||||
else:
|
||||
arr = np.array(value)
|
||||
if arr.dtype == object:
|
||||
raise TypeError(f"Invalid value type {type(value)}")
|
||||
np_data[key] = arr
|
||||
|
||||
meta_group = self.meta
|
||||
if self.backend == "zarr":
|
||||
for key, value in np_data.items():
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
|
||||
)
|
||||
else:
|
||||
meta_group.update(np_data)
|
||||
|
||||
return meta_group
|
||||
|
||||
@property
|
||||
def episode_ends(self):
|
||||
return self.meta["episode_ends"]
|
||||
|
||||
def get_episode_idxs(self):
|
||||
import numba
|
||||
|
||||
numba.jit(nopython=True)
|
||||
|
||||
def _get_episode_idxs(episode_ends):
|
||||
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||
for i in range(len(episode_ends)):
|
||||
start = 0
|
||||
if i > 0:
|
||||
start = episode_ends[i - 1]
|
||||
end = episode_ends[i]
|
||||
for idx in range(start, end):
|
||||
result[idx] = i
|
||||
return result
|
||||
|
||||
return _get_episode_idxs(self.episode_ends)
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
backend = "numpy"
|
||||
if isinstance(self.root, zarr.Group):
|
||||
backend = "zarr"
|
||||
return backend
|
||||
|
||||
# =========== dict-like API ==============
|
||||
def __repr__(self) -> str:
|
||||
if self.backend == "zarr":
|
||||
return str(self.root.tree())
|
||||
else:
|
||||
return super().__repr__()
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def values(self):
|
||||
return self.data.values()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.data
|
||||
|
||||
# =========== our API ==============
|
||||
@property
|
||||
def n_steps(self):
|
||||
if len(self.episode_ends) == 0:
|
||||
return 0
|
||||
return self.episode_ends[-1]
|
||||
|
||||
@property
|
||||
def n_episodes(self):
|
||||
return len(self.episode_ends)
|
||||
|
||||
@property
|
||||
def chunk_size(self):
|
||||
if self.backend == "zarr":
|
||||
return next(iter(self.data.arrays()))[-1].chunks[0]
|
||||
return None
|
||||
|
||||
@property
|
||||
def episode_lengths(self):
|
||||
ends = self.episode_ends[:]
|
||||
ends = np.insert(ends, 0, 0)
|
||||
lengths = np.diff(ends)
|
||||
return lengths
|
||||
|
||||
def add_episode(
|
||||
self,
|
||||
data: dict[str, np.ndarray],
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
):
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
assert len(data) > 0
|
||||
is_zarr = self.backend == "zarr"
|
||||
|
||||
curr_len = self.n_steps
|
||||
episode_length = None
|
||||
for value in data.values():
|
||||
assert len(value.shape) >= 1
|
||||
if episode_length is None:
|
||||
episode_length = len(value)
|
||||
else:
|
||||
assert episode_length == len(value)
|
||||
new_len = curr_len + episode_length
|
||||
|
||||
for key, value in data.items():
|
||||
new_shape = (new_len,) + value.shape[1:]
|
||||
# create array
|
||||
if key not in self.data:
|
||||
if is_zarr:
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
arr = self.data.zeros(
|
||||
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
|
||||
)
|
||||
else:
|
||||
# copy data to prevent modify
|
||||
arr = np.zeros(shape=new_shape, dtype=value.dtype)
|
||||
self.data[key] = arr
|
||||
else:
|
||||
arr = self.data[key]
|
||||
assert value.shape[1:] == arr.shape[1:]
|
||||
# same method for both zarr and numpy
|
||||
if is_zarr:
|
||||
arr.resize(new_shape)
|
||||
else:
|
||||
arr.resize(new_shape, refcheck=False)
|
||||
# copy data
|
||||
arr[-value.shape[0] :] = value
|
||||
|
||||
# append to episode ends
|
||||
episode_ends = self.episode_ends
|
||||
if is_zarr:
|
||||
episode_ends.resize(episode_ends.shape[0] + 1)
|
||||
else:
|
||||
episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
|
||||
episode_ends[-1] = new_len
|
||||
|
||||
# rechunk
|
||||
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
|
||||
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
|
||||
|
||||
def drop_episode(self):
|
||||
is_zarr = self.backend == "zarr"
|
||||
episode_ends = self.episode_ends[:].copy()
|
||||
assert len(episode_ends) > 0
|
||||
start_idx = 0
|
||||
if len(episode_ends) > 1:
|
||||
start_idx = episode_ends[-2]
|
||||
for value in self.data.values():
|
||||
new_shape = (start_idx,) + value.shape[1:]
|
||||
if is_zarr:
|
||||
value.resize(new_shape)
|
||||
else:
|
||||
value.resize(new_shape, refcheck=False)
|
||||
if is_zarr:
|
||||
self.episode_ends.resize(len(episode_ends) - 1)
|
||||
else:
|
||||
self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)
|
||||
|
||||
def pop_episode(self):
|
||||
assert self.n_episodes > 0
|
||||
episode = self.get_episode(self.n_episodes - 1, copy=True)
|
||||
self.drop_episode()
|
||||
return episode
|
||||
|
||||
def extend(self, data):
|
||||
self.add_episode(data)
|
||||
|
||||
def get_episode(self, idx, copy=False):
|
||||
idx = list(range(len(self.episode_ends)))[idx]
|
||||
start_idx = 0
|
||||
if idx > 0:
|
||||
start_idx = self.episode_ends[idx - 1]
|
||||
end_idx = self.episode_ends[idx]
|
||||
result = self.get_steps_slice(start_idx, end_idx, copy=copy)
|
||||
return result
|
||||
|
||||
def get_episode_slice(self, idx):
|
||||
start_idx = 0
|
||||
if idx > 0:
|
||||
start_idx = self.episode_ends[idx - 1]
|
||||
end_idx = self.episode_ends[idx]
|
||||
return slice(start_idx, end_idx)
|
||||
|
||||
def get_steps_slice(self, start, stop, step=None, copy=False):
|
||||
_slice = slice(start, stop, step)
|
||||
|
||||
result = {}
|
||||
for key, value in self.data.items():
|
||||
x = value[_slice]
|
||||
if copy and isinstance(value, np.ndarray):
|
||||
x = x.copy()
|
||||
result[key] = x
|
||||
return result
|
||||
|
||||
# =========== chunking =============
|
||||
def get_chunks(self) -> dict:
|
||||
assert self.backend == "zarr"
|
||||
chunks = {}
|
||||
for key, value in self.data.items():
|
||||
chunks[key] = value.chunks
|
||||
return chunks
|
||||
|
||||
def set_chunks(self, chunks: dict):
|
||||
assert self.backend == "zarr"
|
||||
for key, value in chunks.items():
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
if value != arr.chunks:
|
||||
check_chunks_compatible(chunks=value, shape=arr.shape)
|
||||
rechunk_recompress_array(self.data, key, chunks=value)
|
||||
|
||||
def get_compressors(self) -> dict:
|
||||
assert self.backend == "zarr"
|
||||
compressors = {}
|
||||
for key, value in self.data.items():
|
||||
compressors[key] = value.compressor
|
||||
return compressors
|
||||
|
||||
def set_compressors(self, compressors: dict):
|
||||
assert self.backend == "zarr"
|
||||
for key, value in compressors.items():
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
compressor = self.resolve_compressor(value)
|
||||
if compressor != arr.compressor:
|
||||
rechunk_recompress_array(self.data, key, compressor=compressor)
|
||||
@@ -1,6 +1,6 @@
|
||||
def make_policy(cfg):
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc import TDMPC
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPC
|
||||
|
||||
policy = TDMPC(cfg.policy, cfg.device)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
@@ -8,6 +8,7 @@ def make_policy(cfg):
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
@@ -16,6 +17,12 @@ def make_policy(cfg):
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc_helper as h
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -21,6 +21,8 @@ save_buffer: false
|
||||
train_steps: ???
|
||||
fps: ???
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
n_action_steps: ???
|
||||
env: ???
|
||||
|
||||
@@ -29,5 +31,4 @@ policy: ???
|
||||
wandb:
|
||||
enable: true
|
||||
project: lerobot
|
||||
entity: rcadene # insert your own
|
||||
notes: ""
|
||||
|
||||
25
lerobot/configs/env/aloha.yaml
vendored
Normal file
25
lerobot/configs/env/aloha.yaml
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
# @package _global_
|
||||
|
||||
eval_episodes: 50
|
||||
eval_freq: 7500
|
||||
save_freq: 75000
|
||||
log_freq: 250
|
||||
# TODO: same as simxarm, need to adjust
|
||||
offline_steps: 25000
|
||||
online_steps: 25000
|
||||
|
||||
fps: 50
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: sim_insertion_human
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: [3, 480, 640]
|
||||
action_repeat: 1
|
||||
episode_length: 400
|
||||
fps: ${fps}
|
||||
|
||||
policy:
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
58
lerobot/configs/policy/act.yaml
Normal file
58
lerobot/configs/policy/act.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
# @package _global_
|
||||
|
||||
offline_steps: 1344000
|
||||
online_steps: 0
|
||||
|
||||
eval_episodes: 1
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
|
||||
horizon: 100
|
||||
n_obs_steps: 1
|
||||
n_latency_steps: 0
|
||||
# when temporal_agg=False, n_action_steps=horizon
|
||||
n_action_steps: ${horizon}
|
||||
|
||||
policy:
|
||||
name: act
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
backbone: resnet18
|
||||
num_queries: ${horizon} # chunk_size
|
||||
horizon: ${horizon} # chunk_size
|
||||
kl_weight: 10
|
||||
hidden_dim: 512
|
||||
dim_feedforward: 3200
|
||||
enc_layers: 4
|
||||
dec_layers: 7
|
||||
nheads: 8
|
||||
#camera_names: [top, front_close, left_pillar, right_pillar]
|
||||
camera_names: [top]
|
||||
position_embedding: sine
|
||||
masks: false
|
||||
dilation: false
|
||||
dropout: 0.1
|
||||
pre_norm: false
|
||||
|
||||
vae: true
|
||||
|
||||
batch_size: 8
|
||||
|
||||
per_alpha: 0.6
|
||||
per_beta: 0.4
|
||||
|
||||
balanced_sampling: false
|
||||
utd: 1
|
||||
|
||||
n_obs_steps: ${n_obs_steps}
|
||||
|
||||
temporal_agg: false
|
||||
|
||||
state_dim: ???
|
||||
action_dim: ???
|
||||
@@ -29,6 +29,8 @@ log_freq: 250
|
||||
offline_steps: 1344000
|
||||
online_steps: 0
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
@@ -72,7 +74,6 @@ noise_scheduler:
|
||||
prediction_type: epsilon # or sample
|
||||
|
||||
obs_encoder:
|
||||
# _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
|
||||
shape_meta: ${shape_meta}
|
||||
# resize_shape: null
|
||||
# crop_shape: [76, 76]
|
||||
@@ -83,12 +84,12 @@ obs_encoder:
|
||||
imagenet_norm: True
|
||||
|
||||
rgb_model:
|
||||
_target_: diffusion_policy.model.vision.model_getter.get_resnet
|
||||
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
|
||||
name: resnet18
|
||||
weights: null
|
||||
|
||||
ema:
|
||||
_target_: diffusion_policy.model.diffusion.ema_model.EMAModel
|
||||
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
# TODO(rcadene): obsolete remove
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
|
||||
def download():
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
os.remove(download_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download()
|
||||
@@ -9,13 +9,13 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import set_seed
|
||||
from lerobot.common.utils import init_logging, set_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
@@ -38,27 +38,18 @@ def eval_policy(
|
||||
successes = []
|
||||
threads = []
|
||||
for i in tqdm.tqdm(range(num_episodes)):
|
||||
tensordict = env.reset()
|
||||
|
||||
ep_frames = []
|
||||
|
||||
if save_video or (return_first_video and i == 0):
|
||||
|
||||
def rendering_callback(env, td=None):
|
||||
def render_frame(env):
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
# render first frame before rollout
|
||||
rendering_callback(env)
|
||||
else:
|
||||
rendering_callback = None
|
||||
env.register_rendering_hook(render_frame)
|
||||
|
||||
with torch.inference_mode():
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
callback=rendering_callback,
|
||||
auto_reset=False,
|
||||
tensordict=tensordict,
|
||||
auto_cast_to_device=True,
|
||||
)
|
||||
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
|
||||
@@ -85,6 +76,8 @@ def eval_policy(
|
||||
if return_first_video and i == 0:
|
||||
first_video = stacked_frames.transpose(0, 3, 1, 2)
|
||||
|
||||
env.reset_rendering_hooks()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
@@ -109,16 +102,24 @@ def eval(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
init_logging()
|
||||
|
||||
if cfg.device == "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
else:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_seed(cfg.seed)
|
||||
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
|
||||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer._transform)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
policy = make_policy(cfg)
|
||||
@@ -142,6 +143,8 @@ def eval(cfg: dict, out_dir=None):
|
||||
)
|
||||
print(metrics)
|
||||
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
eval_cli()
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import format_big_number, init_logging, set_seed
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -143,11 +143,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
online_buffer = TensorDictReplayBuffer(
|
||||
storage=LazyMemmapStorage(100_000),
|
||||
sampler=online_sampler,
|
||||
transform=offline_buffer._transform,
|
||||
transform=offline_buffer.transform,
|
||||
)
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer._transform)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
@@ -164,7 +164,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(out_dir, job_name, cfg)
|
||||
|
||||
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||
logging.info(f"{cfg.online_steps=}")
|
||||
@@ -174,7 +174,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
step = 0 # number of policy update
|
||||
step = 0 # number of policy update (forward + backward + optim)
|
||||
|
||||
is_offline = True
|
||||
for offline_step in range(cfg.offline_steps):
|
||||
@@ -193,6 +193,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
)
|
||||
log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
@@ -212,7 +214,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for env_step in range(cfg.online_steps):
|
||||
if env_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
# TODO: use SyncDataCollector for that?
|
||||
# TODO: add configurable number of rollout? (default=1)
|
||||
with torch.no_grad():
|
||||
rollout = env.rollout(
|
||||
@@ -268,6 +269,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
|
||||
from torchrl.data.replay_buffers import (
|
||||
SamplerWithoutReplacement,
|
||||
)
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.utils import init_logging
|
||||
|
||||
NUM_EPISODES_TO_RENDER = 10
|
||||
NUM_EPISODES_TO_RENDER = 50
|
||||
MAX_NUM_STEPS = 1000
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
|
||||
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=1,
|
||||
strict_length=False,
|
||||
init_logging()
|
||||
log_output_dir(out_dir)
|
||||
|
||||
# we expect frames of each episode to be stored next to each others sequentially
|
||||
sampler = SamplerWithoutReplacement(
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg, sampler)
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(
|
||||
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||
)
|
||||
|
||||
for _ in range(NUM_EPISODES_TO_RENDER):
|
||||
episode = offline_buffer.sample(MAX_NUM_STEPS)
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
|
||||
ep_idx = episode["episode"][FIRST_FRAME].item()
|
||||
ep_frames = torch.cat(
|
||||
[
|
||||
episode["observation"]["image"][FIRST_FRAME][None, ...],
|
||||
episode["next", "observation"]["image"],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = offline_buffer.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
# TODO(rcadene): make fps configurable
|
||||
video_path = video_dir / f"episode_{ep_idx}.mp4"
|
||||
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
|
||||
assert ep_frames.min().item() >= 0
|
||||
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
|
||||
assert ep_frames.max().item() <= 255
|
||||
ep_frames = ep_frames.type(torch.uint8)
|
||||
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
|
||||
# ran out of episodes
|
||||
if offline_buffer._sampler._sample_list.numel() == 0:
|
||||
for im_key in offline_buffer.image_keys:
|
||||
if new_episode or no_more_frames:
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(ep_td[("next", *im_key)])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], cfg.fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
current_ep_idx = ep_idx
|
||||
|
||||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
# append current cameras images to list of frames
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
|
||||
if no_more_frames:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_dataset_cli()
|
||||
|
||||
1006
poetry.lock
generated
1006
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -14,10 +14,11 @@ classifiers=[
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
cython = "^3.0.8"
|
||||
@@ -41,19 +42,22 @@ mpmath = "^1.3.0"
|
||||
torch = "^2.2.1"
|
||||
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||
mujoco = "^3.1.2"
|
||||
mujoco = "2.3.7"
|
||||
mujoco-py = "^2.1.2.14"
|
||||
gym = "^0.26.2"
|
||||
opencv-python = "^4.9.0.80"
|
||||
diffusion-policy = {git = "https://github.com/real-stanford/diffusion_policy"}
|
||||
diffusers = "^0.26.3"
|
||||
torchvision = "^0.17.1"
|
||||
h5py = "^3.10.0"
|
||||
dm-control = "1.0.14"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
@@ -82,5 +86,15 @@ exclude = [
|
||||
"venv",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||
build-backend = "poetry_dynamic_versioning.backend"
|
||||
|
||||
@@ -17,6 +17,7 @@ apptainer exec --nv \
|
||||
~/apptainer/nvidia_cuda:12.2.2-devel-ubuntu22.04.sif $SHELL
|
||||
|
||||
source ~/.bashrc
|
||||
conda activate fowm
|
||||
#conda activate fowm
|
||||
conda activate lerobot
|
||||
|
||||
srun $CMD
|
||||
|
||||
17
sbatch_hopper.sh
Normal file
17
sbatch_hopper.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=1 # total number of nodes (N to be defined)
|
||||
#SBATCH --ntasks-per-node=1 # number of tasks per node (here 8 tasks, or 1 task per GPU)
|
||||
#SBATCH --qos=normal # number of GPUs reserved per node (here 8, or all the GPUs)
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --gres=gpu:1 # number of GPUs reserved per node (here 8, or all the GPUs)
|
||||
#SBATCH --cpus-per-task=12 # number of cores per task
|
||||
#SBATCH --mem-per-cpu=11G
|
||||
#SBATCH --time=12:00:00
|
||||
#SBATCH --output=/admin/home/remi_cadene/slurm/%j.out
|
||||
#SBATCH --error=/admin/home/remi_cadene/slurm/%j.err
|
||||
#SBATCH --mail-user=remi_cadene@huggingface.co
|
||||
#SBATCH --mail-type=ALL
|
||||
|
||||
CMD=$@
|
||||
echo "command: $CMD"
|
||||
srun $CMD
|
||||
159
setup.py
159
setup.py
@@ -1,159 +0,0 @@
|
||||
"""A setuptools based setup module.
|
||||
|
||||
See:
|
||||
https://packaging.python.org/en/latest/distributing.html
|
||||
https://github.com/pypa/sampleproject
|
||||
"""
|
||||
|
||||
# To use a consistent encoding
|
||||
from codecs import open
|
||||
from os import path
|
||||
|
||||
# Always prefer setuptools over distutils
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
here = path.abspath(path.dirname(__file__))
|
||||
|
||||
# Get the long description from the README file
|
||||
with open(path.join(here, "README.md"), encoding="utf-8") as f:
|
||||
long_description = f.read()
|
||||
|
||||
# Arguments marked as "Required" below must be included for upload to PyPI.
|
||||
# Fields marked as "Optional" may be commented out.
|
||||
|
||||
# https://stackoverflow.com/questions/458550/standard-way-to-embed-version-into-python-package/16084844#16084844
|
||||
exec(open(path.join(here, "lerobot", "__version__.py")).read())
|
||||
setup(
|
||||
# This is the name of your project. The first time you publish this
|
||||
# package, this name will be registered for you. It will determine how
|
||||
# users can install this project, e.g.:
|
||||
#
|
||||
# $ pip install sampleproject
|
||||
#
|
||||
# And where it will live on PyPI: https://pypi.org/project/sampleproject/
|
||||
#
|
||||
# There are some restrictions on what makes a valid project name
|
||||
# specification here:
|
||||
# https://packaging.python.org/specifications/core-metadata/#name
|
||||
name="lerobot", # Required
|
||||
# Versions should comply with PEP 440:
|
||||
# https://www.python.org/dev/peps/pep-0440/
|
||||
#
|
||||
# For a discussion on single-sourcing the version across setup.py and the
|
||||
# project code, see
|
||||
# https://packaging.python.org/en/latest/single_source_version.html
|
||||
version=__version__, # noqa: F821 # Required
|
||||
# This is a one-line description or tagline of what your project does. This
|
||||
# corresponds to the "Summary" metadata field:
|
||||
# https://packaging.python.org/specifications/core-metadata/#summary
|
||||
description="Le robot is learning", # Required
|
||||
# This is an optional longer description of your project that represents
|
||||
# the body of text which users will see when they visit PyPI.
|
||||
#
|
||||
# Often, this is the same as your README, so you can just read it in from
|
||||
# that file directly (as we have already done above)
|
||||
#
|
||||
# This field corresponds to the "Description" metadata field:
|
||||
# https://packaging.python.org/specifications/core-metadata/#description-optional
|
||||
long_description=long_description, # Optional
|
||||
# This should be a valid link to your project's main homepage.
|
||||
#
|
||||
# This field corresponds to the "Home-Page" metadata field:
|
||||
# https://packaging.python.org/specifications/core-metadata/#home-page-optional
|
||||
url="https://github.com/cadene/lerobot", # Optional
|
||||
# This should be your name or the name of the organization which owns the
|
||||
# project.
|
||||
author="Remi Cadene", # Optional
|
||||
# This should be a valid email address corresponding to the author listed
|
||||
# above.
|
||||
author_email="re.cadene@gmail.com", # Optional
|
||||
# Classifiers help users find your project by categorizing it.
|
||||
#
|
||||
# For a list of valid classifiers, see
|
||||
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
|
||||
classifiers=[ # Optional
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
"Development Status :: 3 - Alpha",
|
||||
# Indicate who your project is intended for
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
# Pick your license as you wish
|
||||
"License :: OSI Approved :: MIT License",
|
||||
# Specify the Python versions you support here. In particular, ensure
|
||||
# that you indicate whether you support Python 2, Python 3 or both.
|
||||
"Programming Language :: Python :: 3.7",
|
||||
],
|
||||
# This field adds keywords for your project which will appear on the
|
||||
# project page. What does your project relate to?
|
||||
#
|
||||
# Note that this is a string of words separated by whitespace, not a list.
|
||||
keywords="pytorch framework bootstrap deep learning scaffolding", # Optional
|
||||
# You can just specify package directories manually here if your project is
|
||||
# simple. Or you can use find_packages().
|
||||
#
|
||||
# Alternatively, if you just want to distribute a single Python file, use
|
||||
# the `py_modules` argument instead as follows, which will expect a file
|
||||
# called `my_module.py` to exist:
|
||||
#
|
||||
# py_modules=["my_module"],
|
||||
#
|
||||
packages=find_packages(
|
||||
exclude=[
|
||||
"data",
|
||||
"logs",
|
||||
]
|
||||
),
|
||||
# This field lists other packages that your project depends on to run.
|
||||
# Any package you put here will be installed by pip when your project is
|
||||
# installed, so they must be valid existing projects.
|
||||
#
|
||||
# For an analysis of "install_requires" vs pip's requirements files see:
|
||||
# https://packaging.python.org/en/latest/requirements.html
|
||||
install_requires=[
|
||||
"torch",
|
||||
"numpy",
|
||||
"argparse",
|
||||
],
|
||||
# List additional groups of dependencies here (e.g. development
|
||||
# dependencies). Users will be able to install these using the "extras"
|
||||
# syntax, for example:
|
||||
#
|
||||
# $ pip install sampleproject[dev]
|
||||
#
|
||||
# Similar to `install_requires` above, these must be valid existing
|
||||
# projects.
|
||||
# extras_require={ # Optional
|
||||
# 'dev': ['check-manifest'],
|
||||
# 'test': ['coverage'],
|
||||
# },
|
||||
# If there are data files included in your packages that need to be
|
||||
# installed, specify them here.
|
||||
#
|
||||
# If using Python 2.6 or earlier, then these have to be included in
|
||||
# MANIFEST.in as well.
|
||||
# package_data={ # Optional
|
||||
# 'sample': ['package_data.dat'],
|
||||
# },
|
||||
include_package_data=True,
|
||||
# Although 'package_data' is the preferred approach, in some case you may
|
||||
# need to place data files outside of your packages. See:
|
||||
# http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files
|
||||
#
|
||||
# In this case, 'data_file' will be installed into '<sys.prefix>/my_data'
|
||||
# data_files=[('my_data', ['data/data_file'])], # Optional
|
||||
# To provide executable scripts, use entry points in preference to the
|
||||
# "scripts" keyword. Entry points provide cross-platform support and allow
|
||||
# `pip` to create the appropriate form of executable for the target
|
||||
# platform.
|
||||
#
|
||||
# For example, the following would provide a command called `sample` which
|
||||
# executes the function `main` from this package when invoked:
|
||||
# entry_points={ # Optional
|
||||
# 'console_scripts': [
|
||||
# 'sample=sample:main',
|
||||
# ],
|
||||
# },
|
||||
)
|
||||
3
tests/data/aloha_sim_insertion_human/action.memmap
Normal file
3
tests/data/aloha_sim_insertion_human/action.memmap
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d789deddb081a9f4b626342391de8f48949d38fb5fdead87b5c0737b46c0877a
|
||||
size 2800
|
||||
3
tests/data/aloha_sim_insertion_human/episode.memmap
Normal file
3
tests/data/aloha_sim_insertion_human/episode.memmap
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7a12e561363385e9dfeeab326368731c030ed4b374e7f5897ac819159d2884c5
|
||||
size 400
|
||||
3
tests/data/aloha_sim_insertion_human/frame_id.memmap
Normal file
3
tests/data/aloha_sim_insertion_human/frame_id.memmap
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c202d9cfc7858fd49d522047e16948359bbbb2eda2d3825d552e45a78d5f8585
|
||||
size 400
|
||||
1
tests/data/aloha_sim_insertion_human/meta.json
Normal file
1
tests/data/aloha_sim_insertion_human/meta.json
Normal file
@@ -0,0 +1 @@
|
||||
{"action": {"device": "cpu", "shape": [50, 14], "dtype": "torch.float32"}, "episode": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "frame_id": {"device": "cpu", "shape": [50], "dtype": "torch.int64"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
||||
3
tests/data/aloha_sim_insertion_human/next/done.memmap
Normal file
3
tests/data/aloha_sim_insertion_human/next/done.memmap
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cc2786e1f9910a9d811400edcddaf7075195f7a16b216dcbefba3bc7c4f2ae51
|
||||
size 50
|
||||
1
tests/data/aloha_sim_insertion_human/next/meta.json
Normal file
1
tests/data/aloha_sim_insertion_human/next/meta.json
Normal file
@@ -0,0 +1 @@
|
||||
{"done": {"device": "cpu", "shape": [50, 1], "dtype": "torch.bool"}, "shape": [50], "device": "cpu", "_type": "<class 'tensordict._td.TensorDict'>"}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user