Compare commits
230 Commits
user/rcade
...
thom-propo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1e47202c0 | ||
|
|
24821fee24 | ||
|
|
4751642ace | ||
|
|
11cbf1bea1 | ||
|
|
f1148b8c2d | ||
|
|
2a98cc71ed | ||
|
|
a7c9b78e56 | ||
|
|
404b8f8a75 | ||
|
|
17c2bbbeb8 | ||
|
|
006e5feabf | ||
|
|
b99ee8180a | ||
|
|
6bddcb647e | ||
|
|
58df2066a9 | ||
|
|
c89aa4f8ed | ||
|
|
62aad7104b | ||
|
|
9d9148dad8 | ||
|
|
1b6cb2b1be | ||
|
|
6f1a0aefab | ||
|
|
b7c9c33072 | ||
|
|
120f0aef5c | ||
|
|
032200e32c | ||
|
|
de1e9187c8 | ||
|
|
4f8f1926f9 | ||
|
|
6710121a29 | ||
|
|
5f4b8ab899 | ||
|
|
18e7f4c3e6 | ||
|
|
643d64e2a8 | ||
|
|
c037722e23 | ||
|
|
6cd671040f | ||
|
|
b6353964ba | ||
|
|
64c8851c40 | ||
|
|
dc745e3037 | ||
|
|
6f0c2445ca | ||
|
|
d1d2229407 | ||
|
|
68d02c80cf | ||
|
|
011f2d27fe | ||
|
|
be4441c7ff | ||
|
|
1ed0110900 | ||
|
|
cb6d1e0871 | ||
|
|
9ced0cf1fb | ||
|
|
98534d1a63 | ||
|
|
edacc1d2a0 | ||
|
|
5a46b8a2a9 | ||
|
|
4a8c5e238e | ||
|
|
1a1308d62f | ||
|
|
203bcd7ca5 | ||
|
|
98b9631aa6 | ||
|
|
c5635b7d94 | ||
|
|
f00252552a | ||
|
|
90f6af9736 | ||
|
|
bcfdba109f | ||
|
|
0fae5b206b | ||
|
|
7cdd6d2450 | ||
|
|
058ac991eb | ||
|
|
d3adaf1379 | ||
|
|
dc89166bee | ||
|
|
5ef813ff1e | ||
|
|
a2ac83276b | ||
|
|
c0833f1c2d | ||
|
|
de5c30405e | ||
|
|
462e7469e8 | ||
|
|
298d391b26 | ||
|
|
be6364f109 | ||
|
|
127de1258d | ||
|
|
b905111895 | ||
|
|
0c41675986 | ||
|
|
1c24bbda3f | ||
|
|
e41c420a96 | ||
|
|
4a48b77540 | ||
|
|
f3cfc8b3b4 | ||
|
|
d2ef43436c | ||
|
|
40f3783fca | ||
|
|
e21ed6f510 | ||
|
|
bd40ffc53c | ||
|
|
d43fa600a0 | ||
|
|
e698d38a35 | ||
|
|
15ff3b3af8 | ||
|
|
a80d9c0257 | ||
|
|
b9047fbdd2 | ||
|
|
115927d0f6 | ||
|
|
529f42643d | ||
|
|
1b279a1fc0 | ||
|
|
3f0f95f4c0 | ||
|
|
8720c568d0 | ||
|
|
b633748987 | ||
|
|
41912b962b | ||
|
|
98361073ef | ||
|
|
48df15ed26 | ||
|
|
b562f89c3b | ||
|
|
4e10cd306b | ||
|
|
72d3c3120b | ||
|
|
acf1174447 | ||
|
|
1bd50122be | ||
|
|
2b0221052a | ||
|
|
4631d36c05 | ||
|
|
f23a53c3e4 | ||
|
|
82e6e01651 | ||
|
|
d323993569 | ||
|
|
ec536ef0fa | ||
|
|
3910c48e43 | ||
|
|
4b7ec81dde | ||
|
|
98a816f0f8 | ||
|
|
45a4a02b7e | ||
|
|
8bed0fc465 | ||
|
|
32e3f71dd1 | ||
|
|
5332766a82 | ||
|
|
b1ec3da035 | ||
|
|
d16f6a93b3 | ||
|
|
52e149fbfd | ||
|
|
4f1955edfd | ||
|
|
c5010fee9a | ||
|
|
18fa88475b | ||
|
|
b54cdc9a0f | ||
|
|
46ac87d2a6 | ||
|
|
896a11f60e | ||
|
|
2d5abbbd6f | ||
|
|
7d5d99e036 | ||
|
|
b420ab88f4 | ||
|
|
e799dc5e3f | ||
|
|
10034e85c4 | ||
|
|
ea17f4ce50 | ||
|
|
6a1a29386a | ||
|
|
88347965c2 | ||
|
|
09ddd9bf92 | ||
|
|
099a465367 | ||
|
|
8e346b379d | ||
|
|
bae7e7b41c | ||
|
|
75cc10198f | ||
|
|
3124f71ebd | ||
|
|
4ecfd17f9e | ||
|
|
58d1787ee3 | ||
|
|
b752833f3f | ||
|
|
a45896dc8d | ||
|
|
9c88071bc7 | ||
|
|
5805a7ffb1 | ||
|
|
41521f7e96 | ||
|
|
b10c9507d4 | ||
|
|
a311d38796 | ||
|
|
19730b3412 | ||
|
|
a222c88c99 | ||
|
|
736bc969ca | ||
|
|
4822d63dbe | ||
|
|
ba91976944 | ||
|
|
95e84079ef | ||
|
|
8e856f1bf7 | ||
|
|
8c2b47752a | ||
|
|
f515cb6efd | ||
|
|
c3f8d14fd8 | ||
|
|
98484ac68e | ||
|
|
9512d1d2f3 | ||
|
|
8c56770318 | ||
|
|
998dd2b874 | ||
|
|
7331df81d2 | ||
|
|
2c5d49cad5 | ||
|
|
5881eec376 | ||
|
|
29c73844b1 | ||
|
|
f9258898ff | ||
|
|
9d002032d1 | ||
|
|
060bac7672 | ||
|
|
337208f28d | ||
|
|
87fcc536f9 | ||
|
|
48e70e044e | ||
|
|
4449c06823 | ||
|
|
304355c917 | ||
|
|
2a01487494 | ||
|
|
a94800fc8a | ||
|
|
a207b416b7 | ||
|
|
78690d197f | ||
|
|
6d6c84b4a3 | ||
|
|
772a826bf2 | ||
|
|
2cb8ae5037 | ||
|
|
fab2b3240b | ||
|
|
84a1647c01 | ||
|
|
ccd5dc5a42 | ||
|
|
c1e9c13ade | ||
|
|
00fe4f4f18 | ||
|
|
225eebde40 | ||
|
|
816b2e9d63 | ||
|
|
a7ef4a6a33 | ||
|
|
d4ea4f0ad1 | ||
|
|
f54ee7cda0 | ||
|
|
134009f337 | ||
|
|
7982425670 | ||
|
|
6c867d78ef | ||
|
|
302b78962c | ||
|
|
59397fb44a | ||
|
|
1cc621ec36 | ||
|
|
471ebfef62 | ||
|
|
30753d879c | ||
|
|
c6fb40fb29 | ||
|
|
fa7a947acc | ||
|
|
450e32e4b5 | ||
|
|
0da85b2cef | ||
|
|
f2c7ab5b3b | ||
|
|
cde866dac0 | ||
|
|
a54a0feb63 | ||
|
|
f440a681ad | ||
|
|
35bd577deb | ||
|
|
327f60e4be | ||
|
|
74ad9d5154 | ||
|
|
89eaab140b | ||
|
|
7dbdbb051c | ||
|
|
4cc7e1539e | ||
|
|
f1e2837d63 | ||
|
|
54b05bfb77 | ||
|
|
524d29aa80 | ||
|
|
c2c0ef9927 | ||
|
|
66373e9b13 | ||
|
|
7d33b437fa | ||
|
|
b9dc3be463 | ||
|
|
86ec62f98a | ||
|
|
52bdfc659e | ||
|
|
d782b029e1 | ||
|
|
49c0955f97 | ||
|
|
eed24b083a | ||
|
|
f95ecd66fc | ||
|
|
d34c0a3c49 | ||
|
|
11a5a7ca45 | ||
|
|
a6d353c419 | ||
|
|
d6556e6519 | ||
|
|
12af67066d | ||
|
|
7a20ef65f6 | ||
|
|
2f80d71c3e | ||
|
|
d4e0849970 | ||
|
|
e132a267aa | ||
|
|
a027f4edfb | ||
|
|
570f8d01df | ||
|
|
7938adcdfc | ||
|
|
20c08bb740 | ||
|
|
2bcf2631b9 |
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||
3333
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
3333
.github/poetry/cpu/poetry.lock
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
109
.github/poetry/cpu/pyproject.toml
vendored
Normal file
109
.github/poetry/cpu/pyproject.toml
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
[tool.poetry]
|
||||
name = "lerobot"
|
||||
version = "0.1.0"
|
||||
description = "Le robot is learning"
|
||||
authors = [
|
||||
"Rémi Cadène <re.cadene@gmail.com>",
|
||||
"Simon Alibert <alibert.sim@gmail.com>",
|
||||
]
|
||||
repository = "https://github.com/Cadene/lerobot"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Topic :: Software Development :: Build Tools",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
termcolor = "^2.4.0"
|
||||
omegaconf = "^2.3.0"
|
||||
dm-env = "^1.6"
|
||||
pandas = "^2.2.1"
|
||||
wandb = "^0.16.3"
|
||||
moviepy = "^1.0.3"
|
||||
imageio = {extras = ["pyav"], version = "^2.34.0"}
|
||||
gdown = "^5.1.0"
|
||||
hydra-core = "^1.3.2"
|
||||
einops = "^0.7.0"
|
||||
pygame = "^2.5.2"
|
||||
pymunk = "^6.6.0"
|
||||
zarr = "^2.17.0"
|
||||
shapely = "^2.0.3"
|
||||
scikit-image = "^0.22.0"
|
||||
numba = "^0.59.0"
|
||||
mpmath = "^1.3.0"
|
||||
torch = {version = "^2.2.1", source = "torch-cpu"}
|
||||
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||
mujoco = "^2.3.7"
|
||||
opencv-python = "^4.9.0.80"
|
||||
diffusers = "^0.26.3"
|
||||
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
||||
h5py = "^3.10.0"
|
||||
dm = "^1.3"
|
||||
dm-control = "1.0.14"
|
||||
robomimic = "0.2.0"
|
||||
huggingface-hub = "^0.21.4"
|
||||
gymnasium-robotics = "^1.2.4"
|
||||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
debugpy = "^1.8.1"
|
||||
pytest = "^8.1.0"
|
||||
pytest-cov = "^5.0.0"
|
||||
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "torch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
priority = "supplemental"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
|
||||
[tool.poetry-dynamic-versioning]
|
||||
enable = true
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||
build-backend = "poetry_dynamic_versioning.backend"
|
||||
234
.github/workflows/test.yml
vendored
Normal file
234
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,234 @@
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
types: [opened, synchronize, reopened, labeled]
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
if: |
|
||||
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
|
||||
${{ github.event_name == 'push' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
POETRY_VERSION: 1.8.2
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
#----------------------------------------------
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up python
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
#----------------------------------------------
|
||||
# install & configure poetry
|
||||
#----------------------------------------------
|
||||
- name: Load cached Poetry installation
|
||||
id: restore-poetry-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
version: ${{ env.POETRY_VERSION }}
|
||||
virtualenvs-create: true
|
||||
installer-parallel: true
|
||||
|
||||
- name: Save cached Poetry installation
|
||||
if: |
|
||||
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-poetry-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: ~/.local
|
||||
key: poetry-${{ env.POETRY_VERSION }}
|
||||
|
||||
- name: Configure Poetry
|
||||
run: poetry config virtualenvs.in-project true
|
||||
|
||||
#----------------------------------------------
|
||||
# install dependencies
|
||||
#----------------------------------------------
|
||||
# TODO(aliberts): move to gpu runners
|
||||
- name: Select cpu dependencies # HACK
|
||||
run: cp -t . .github/poetry/cpu/pyproject.toml .github/poetry/cpu/poetry.lock
|
||||
|
||||
- name: Load cached venv
|
||||
id: restore-dependencies-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
||||
env:
|
||||
TMPDIR: ~/tmp
|
||||
TEMP: ~/tmp
|
||||
TMP: ~/tmp
|
||||
run: |
|
||||
mkdir ~/tmp
|
||||
poetry install --no-interaction --no-root
|
||||
|
||||
- name: Save cached venv
|
||||
if: |
|
||||
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-dependencies-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
|
||||
- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
#----------------------------------------------
|
||||
# install project
|
||||
#----------------------------------------------
|
||||
- name: Install project
|
||||
run: poetry install --no-interaction
|
||||
|
||||
#----------------------------------------------
|
||||
# run tests & coverage
|
||||
#----------------------------------------------
|
||||
- name: Run tests
|
||||
env:
|
||||
LEROBOT_TESTS_DEVICE: cpu
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest --cov=./lerobot --cov-report=xml tests
|
||||
|
||||
# TODO(aliberts): Link with HF Codecov account
|
||||
# - name: Upload coverage reports to Codecov with GitHub Action
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# files: ./coverage.xml
|
||||
# verbose: true
|
||||
|
||||
#----------------------------------------------
|
||||
# run end-to-end tests
|
||||
#----------------------------------------------
|
||||
- name: Test train ACT on ALOHA end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
horizon=20 \
|
||||
policy.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
- name: Test eval ACT on ALOHA end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/act/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||
|
||||
# TODO(aliberts): This takes ~2mn to run, needs to be improved
|
||||
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
|
||||
# run: |
|
||||
# source .venv/bin/activate
|
||||
# python lerobot/scripts/eval.py \
|
||||
# --config lerobot/configs/default.yaml \
|
||||
# policy=act \
|
||||
# env=aloha \
|
||||
# eval_episodes=1 \
|
||||
# device=cpu
|
||||
|
||||
- name: Test train Diffusion on PushT end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
- name: Test eval Diffusion on PushT end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/diffusion/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||
|
||||
- name: Test eval Diffusion on PushT end-to-end (policy is None)
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
eval_episodes=1 \
|
||||
device=cpu
|
||||
|
||||
- name: Test train TDMPC on Simxarm end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=simxarm \
|
||||
wandb.enable=False \
|
||||
offline_steps=1 \
|
||||
online_steps=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
- name: Test eval TDMPC on Simxarm end-to-end
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||
|
||||
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/eval.py \
|
||||
--config lerobot/configs/default.yaml \
|
||||
policy=tdmpc \
|
||||
env=simxarm \
|
||||
eval_episodes=1 \
|
||||
device=cpu
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,6 +1,3 @@
|
||||
# Custom
|
||||
diffusion_policy
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -54,6 +51,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
exclude: ^(data/|tests/|diffusion_policy/)
|
||||
exclude: ^(data/|tests/)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
@@ -14,11 +14,11 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.15.1
|
||||
rev: v3.15.2
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.2
|
||||
rev: v0.3.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
|
||||
507
LICENSE
Normal file
507
LICENSE
Normal file
@@ -0,0 +1,507 @@
|
||||
Copyright 2024 The Hugging Face team. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from Diffusion Policy, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from FOWM, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Yunhai Feng
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from simxarm, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nicklas Hansen & Yanjie Ze
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from ALOHA, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Tony Z. Zhao
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
## Some of lerobot's code is derived from DETR, which is subject to the following copyright notice:
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 - present, Facebook, Inc
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
396
README.md
396
README.md
@@ -1,83 +1,360 @@
|
||||
# LeRobot
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
|
||||
</picture>
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
|
||||
[](https://codecov.io/gh/huggingface/lerobot)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
</div>
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
#### Examples of pretrained models and environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
|
||||
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
|
||||
|
||||
## Installation
|
||||
|
||||
Create a virtual environment with python 3.10, e.g. using `conda`:
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
|
||||
```
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
Then, install 🤗 LeRobot:
|
||||
```bash
|
||||
python -m pip install .
|
||||
```
|
||||
|
||||
Install dependencies
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||
```bash
|
||||
wandb login
|
||||
```
|
||||
|
||||
## Walkthrough
|
||||
|
||||
```
|
||||
.
|
||||
├── lerobot
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, simxarm
|
||||
| | └── policies # various policies: act, diffusion, tdmpc
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| └── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
├── .github
|
||||
| └── workflows
|
||||
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
|
||||
└── tests # contains pytest utilities for continuous integration
|
||||
|
||||
```
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
|
||||
```python
|
||||
""" Copy pasted from `examples/1_visualize_dataset.py` """
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||
from lerobot.scripts.visualize_dataset import render_dataset
|
||||
|
||||
print(lerobot.available_datasets)
|
||||
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||
|
||||
# we use this sampler to sample 1 frame after the other
|
||||
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
|
||||
|
||||
video_paths = render_dataset(
|
||||
dataset,
|
||||
out_dir="outputs/visualize_dataset/example",
|
||||
max_num_samples=300,
|
||||
fps=50,
|
||||
)
|
||||
print(video_paths)
|
||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
```
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
env=aloha \
|
||||
task=sim_sim_transfer_cube_human \
|
||||
hydra.run.dir=outputs/visualize_dataset/example
|
||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
```
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
After training your own policy, you can also re-evaluate the checkpoints with:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_dir
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
|
||||
|
||||
In general, you can use our training script to easily train any policy on any environment:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
env=aloha \
|
||||
task=sim_insertion \
|
||||
dataset_id=aloha_sim_insertion_scripted \
|
||||
policy=act \
|
||||
hydra.run.dir=outputs/train/aloha_act
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||
|
||||
### TODO
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
### Follow our style
|
||||
|
||||
```bash
|
||||
# install if needed
|
||||
pre-commit install
|
||||
# apply style and linter checks before git commit
|
||||
pre-commit
|
||||
```
|
||||
|
||||
### Add dependencies
|
||||
|
||||
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
|
||||
Install the project with:
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
|
||||
```
|
||||
mkdir ~/tmp
|
||||
export TMPDIR='~/tmp'
|
||||
Then, the equivalent of `pip install some-package`, would just be:
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
Install `diffusion_policy` #HACK
|
||||
```
|
||||
# from this directory
|
||||
git clone https://github.com/real-stanford/diffusion_policy
|
||||
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
|
||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
||||
```bash
|
||||
# Add the new package to your main poetry env
|
||||
poetry add some-package
|
||||
# Add the same package to the CPU-only env dedicated to CI
|
||||
conda create -y -n lerobot-ci python=3.10
|
||||
conda activate lerobot-ci
|
||||
cd .github/poetry/cpu
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
## Usage
|
||||
### Run tests locally
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
### Train
|
||||
|
||||
```
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.job.name=pusht \
|
||||
env=pusht
|
||||
On Mac:
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
### Visualize offline buffer
|
||||
|
||||
```
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
|
||||
env=pusht
|
||||
On Ubuntu:
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
### Visualize online buffer / Eval
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py \
|
||||
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
|
||||
env=pusht
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
When adding a new dataset, mock it with
|
||||
```bash
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
## TODO
|
||||
Run tests
|
||||
```bash
|
||||
DATA_DIR="tests/data" pytest -sx tests
|
||||
```
|
||||
|
||||
- [x] priority update doesnt match FOWM or original paper
|
||||
- [x] self.step=100000 should be updated at every step to adjust to horizon of planner
|
||||
- [ ] prefetch replay buffer to speedup training
|
||||
- [ ] parallelize env to speedup eval
|
||||
- [ ] clean checkpointing / loading
|
||||
- [ ] clean logging
|
||||
- [ ] clean config
|
||||
- [ ] clean hyperparameter tuning
|
||||
- [ ] add pusht
|
||||
- [ ] add aloha
|
||||
- [ ] add act
|
||||
- [ ] add diffusion
|
||||
- [ ] add aloha 2
|
||||
### Add a new dataset
|
||||
|
||||
## Profile
|
||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
**Example**
|
||||
Then you can upload it to the hub with:
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
You will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.0",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
|
||||
```bash
|
||||
HF_USER=lerobot
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
If you want to improve an existing dataset, you can download it locally with:
|
||||
```bash
|
||||
mkdir -p data/$DATASET
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
|
||||
--repo-type dataset \
|
||||
--local-dir data/$DATASET \
|
||||
--local-dir-use-symlinks=False \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
Iterate on your code and dataset with:
|
||||
```bash
|
||||
DATA_DIR=data python train.py
|
||||
```
|
||||
|
||||
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.1 \
|
||||
--delete "*"
|
||||
```
|
||||
|
||||
Then you will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
|
||||
Finally, you might want to mock the dataset if you need to update the unit tests as well:
|
||||
```bash
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the HuggingFace hub.
|
||||
|
||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||
|
||||
Secondly, assuming you have trained a policy, you need:
|
||||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
├── model.pt
|
||||
└── stats.pth
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
|
||||
```
|
||||
|
||||
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload
|
||||
```
|
||||
|
||||
See `eval.py` for an example of how a user may use your policy.
|
||||
|
||||
|
||||
### Improve your code with profiling
|
||||
|
||||
An example of a code snippet to profile the evaluation of a policy:
|
||||
```python
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
@@ -96,25 +373,12 @@ with profile(
|
||||
with record_function("eval_policy"):
|
||||
for i in range(num_episodes):
|
||||
prof.step()
|
||||
# insert code to profile, potentially whole body of eval_policy function
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
|
||||
--config outputs/pusht/.hydra/config.yaml \
|
||||
pretrained_model_path=outputs/pusht/model.pt \
|
||||
eval_episodes=7
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
**Style**
|
||||
```
|
||||
# install if needed
|
||||
pre-commit install
|
||||
# apply style and linter checks before git commit
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
**Tests**
|
||||
```
|
||||
pytest -sx tests
|
||||
```
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
name: lerobot
|
||||
dependencies:
|
||||
- python=3.8.16
|
||||
- pytorch::pytorch=1.13.1
|
||||
- pytorch::torchvision=0.14.1
|
||||
- nvidia::cudatoolkit=11.7
|
||||
- anaconda::pip
|
||||
- pip:
|
||||
- cython==0.29.33
|
||||
- mujoco==2.3.2
|
||||
- mujoco-py==2.1.2.14
|
||||
- termcolor
|
||||
- omegaconf
|
||||
- gym==0.21.0
|
||||
- dm-env==1.6
|
||||
- pandas
|
||||
- wandb
|
||||
- moviepy
|
||||
- imageio
|
||||
- gdown
|
||||
# - -e benchmarks/d4rl
|
||||
# TODO: verify this works
|
||||
- git+https://github.com/nicklashansen/simxarm.git@main#egg=simxarm
|
||||
24
examples/1_visualize_dataset.py
Normal file
24
examples/1_visualize_dataset.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
from lerobot.scripts.visualize_dataset import render_dataset
|
||||
|
||||
print(lerobot.available_datasets)
|
||||
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||
|
||||
# we use this sampler to sample 1 frame after the other
|
||||
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
|
||||
|
||||
video_paths = render_dataset(
|
||||
dataset,
|
||||
out_dir="outputs/visualize_dataset/example",
|
||||
max_num_samples=300,
|
||||
fps=50,
|
||||
)
|
||||
print(video_paths)
|
||||
# ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
39
examples/2_evaluate_pretrained_policy.py
Normal file
39
examples/2_evaluate_pretrained_policy.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
# Get a pretrained policy from the hub.
|
||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||
folder = Path(snapshot_download(hub_id))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# folder = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
stats_path = folder / "stats.pth" # normalization stats
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
# Create a Hydra config.
|
||||
cfg = init_hydra_config(config_path, overrides)
|
||||
|
||||
# Evaluate the policy and save the outputs including metrics and videos.
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
55
examples/3_train_policy.py
Normal file
55
examples/3_train_policy.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
overrides = [
|
||||
"env=pusht",
|
||||
"policy=diffusion",
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
"offline_steps=5000",
|
||||
"log_freq=250",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
policy.train()
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
for offline_step in trange(cfg.offline_steps):
|
||||
train_info = policy.update(offline_buffer, offline_step)
|
||||
if offline_step % cfg.log_freq == 0:
|
||||
print(train_info)
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
policy.save_pretrained(output_directory / "model.pt")
|
||||
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import lerobot
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_policies)
|
||||
```
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
available_envs = [
|
||||
"aloha",
|
||||
"pusht",
|
||||
"simxarm",
|
||||
]
|
||||
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"sim_insertion",
|
||||
"sim_transfer_cube",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["lift"],
|
||||
}
|
||||
|
||||
available_datasets_per_env = {
|
||||
"aloha": [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["xarm_lift_medium"],
|
||||
}
|
||||
|
||||
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
|
||||
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
]
|
||||
|
||||
@@ -1 +1,8 @@
|
||||
__version__ = "0.0.0"
|
||||
"""To enable `lerobot.__version__`"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
try:
|
||||
__version__ = version("lerobot")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "unknown"
|
||||
|
||||
207
lerobot/common/datasets/abstract.py
Normal file
207
lerobot/common/datasets/abstract.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
|
||||
HF_USER = "lerobot"
|
||||
|
||||
|
||||
class AbstractDataset(TensorDictReplayBuffer):
|
||||
"""
|
||||
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
||||
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
||||
These implementations can vary based on the source of the data, the environment the data pertains to,
|
||||
or the specific kind of data manipulation applied.
|
||||
|
||||
Note:
|
||||
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
||||
functionality for storing and retrieving `TensorDict`-like data.
|
||||
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
||||
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
||||
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
||||
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
available_datasets: list[str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
||||
assert (
|
||||
dataset_id in self.available_datasets
|
||||
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
||||
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.shuffle = shuffle
|
||||
self.root = root if root is None else Path(root)
|
||||
|
||||
if self.root is not None and self.version is not None:
|
||||
logging.warning(
|
||||
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||
)
|
||||
|
||||
storage = self._download_or_load_dataset()
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> c",
|
||||
("observation", "image"): "b c h w -> c 1 1",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
return self._transform
|
||||
|
||||
def set_transform(self, transform):
|
||||
if not isinstance(transform, Compose):
|
||||
# required since torchrl calls `len(self._transform)` downstream
|
||||
if isinstance(transform, list):
|
||||
self._transform = Compose(*transform)
|
||||
else:
|
||||
self._transform = Compose(transform)
|
||||
else:
|
||||
self._transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||
if self.root is None:
|
||||
self.data_dir = Path(
|
||||
snapshot_download(
|
||||
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
||||
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=self._storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
||||
# compute mean, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
if key not in mean:
|
||||
# first batch initialize mean, min, max
|
||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||
else:
|
||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
batch = rb.sample()
|
||||
|
||||
for key in self.stats_patterns:
|
||||
mean[key] /= num_batch
|
||||
|
||||
# compute std, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
if key not in std:
|
||||
# first batch initialize std
|
||||
std[key] = (batch_mean - mean[key]) ** 2
|
||||
else:
|
||||
std[key] += (batch_mean - mean[key]) ** 2
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / num_batch)
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
185
lerobot/common/datasets/aloha.py
Normal file
185
lerobot/common/datasets/aloha.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
EP48_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
EP49_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
NUM_EPISODES = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
EPISODE_LEN = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
CAMERAS = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in DATASET_IDS
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaDataset(AbstractDataset):
|
||||
available_datasets = DATASET_IDS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
("observation", "state"): "b c -> c",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
||||
return d
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_num_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_num_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_num_frames=}")
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
idxtd = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
)
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
@@ -1,36 +1,27 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
# TODO(rcadene): implement
|
||||
|
||||
# dataset_d4rl = D4RLExperienceReplay(
|
||||
# dataset_id="maze2d-umaze-v1",
|
||||
# split_trajs=False,
|
||||
# batch_size=1,
|
||||
# sampler=SamplerWithoutReplacement(drop_last=False),
|
||||
# prefetch=4,
|
||||
# direct_download=True,
|
||||
# )
|
||||
|
||||
# dataset_openx = OpenXExperienceReplay(
|
||||
# "cmu_stretch",
|
||||
# batch_size=1,
|
||||
# num_slices=1,
|
||||
# #download="force",
|
||||
# streaming=False,
|
||||
# root="data",
|
||||
# )
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
# to load a subset of our datasets for faster continuous integration.
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(cfg, sampler=None):
|
||||
def make_offline_buffer(
|
||||
cfg,
|
||||
overwrite_sampler=None,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
@@ -43,50 +34,103 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
overwrite_sampler = sampler is not None
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if not overwrite_sampler:
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
|
||||
if cfg.offline_prioritized_sampler:
|
||||
logging.info("use prioritized sampler for offline dataset")
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
logging.info("use simple sampler for offline dataset")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||
offline_buffer = SimxarmExperienceReplay(
|
||||
f"xarm_{cfg.env.task}_medium",
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root=str(DATA_DIR),
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
|
||||
clsfunc = SimxarmDataset
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
streaming=False,
|
||||
root=DATA_DIR,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
clsfunc = PushtDataset
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
|
||||
clsfunc = AlohaDataset
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
if normalize:
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
index = torch.arange(0, num_steps, 1)
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -12,21 +9,18 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
@@ -54,8 +48,10 @@ def add_tee(
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=DEFAULT_TEE_MASK,
|
||||
mask=None,
|
||||
):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
@@ -87,114 +83,41 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
class PushtDataset(AbstractDataset):
|
||||
available_datasets = ["pusht"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa: F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("pusht")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
self.root = root
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
stats = self._compute_or_load_stats(storage)
|
||||
transform = NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy
|
||||
# We need to automate this for tdmpc and others
|
||||
# ("observation", "image"),
|
||||
("observation", "state"),
|
||||
# TODO(rcadene): for tdmpc, we might want next image and state
|
||||
# ("next", "observation", "image"),
|
||||
# ("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
mode="min_max",
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
transform.stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self) -> Path:
|
||||
return None if self.streaming else self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_path_root.is_dir()
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
raw_dir = self.root / "raw"
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -208,6 +131,9 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
@@ -225,6 +151,8 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
# idx1 = 51
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
@@ -266,8 +194,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
print("before " + """episode = TensorDict(""")
|
||||
episode = TensorDict(
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
@@ -286,120 +213,11 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(episode)] = episode
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(episode)
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
|
||||
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
batch = rb.sample()
|
||||
|
||||
image_channels = batch["observation", "image"].shape[1]
|
||||
image_mean = torch.zeros(image_channels)
|
||||
image_std = torch.zeros(image_channels)
|
||||
image_max = torch.tensor([-math.inf] * image_channels)
|
||||
image_min = torch.tensor([math.inf] * image_channels)
|
||||
|
||||
state_channels = batch["observation", "state"].shape[1]
|
||||
state_mean = torch.zeros(state_channels)
|
||||
state_std = torch.zeros(state_channels)
|
||||
state_max = torch.tensor([-math.inf] * state_channels)
|
||||
state_min = torch.tensor([math.inf] * state_channels)
|
||||
|
||||
action_channels = batch["action"].shape[1]
|
||||
action_mean = torch.zeros(action_channels)
|
||||
action_std = torch.zeros(action_channels)
|
||||
action_max = torch.tensor([-math.inf] * action_channels)
|
||||
action_min = torch.tensor([math.inf] * action_channels)
|
||||
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
state_mean += einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
action_mean += einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
batch = rb.sample()
|
||||
|
||||
image_mean /= num_batch
|
||||
state_mean /= num_batch
|
||||
action_mean /= num_batch
|
||||
|
||||
for i in tqdm.tqdm(range(num_batch)):
|
||||
b_image_mean = einops.reduce(batch["observation", "image"], "b c h w -> c", "mean")
|
||||
b_state_mean = einops.reduce(batch["observation", "state"], "b c -> c", "mean")
|
||||
b_action_mean = einops.reduce(batch["action"], "b c -> c", "mean")
|
||||
image_std += (b_image_mean - image_mean) ** 2
|
||||
state_std += (b_state_mean - state_mean) ** 2
|
||||
action_std += (b_action_mean - action_mean) ** 2
|
||||
|
||||
b_image_max = einops.reduce(batch["observation", "image"], "b c h w -> c", "max")
|
||||
b_image_min = einops.reduce(batch["observation", "image"], "b c h w -> c", "min")
|
||||
b_state_max = einops.reduce(batch["observation", "state"], "b c -> c", "max")
|
||||
b_state_min = einops.reduce(batch["observation", "state"], "b c -> c", "min")
|
||||
b_action_max = einops.reduce(batch["action"], "b c -> c", "max")
|
||||
b_action_min = einops.reduce(batch["action"], "b c -> c", "min")
|
||||
image_max = torch.maximum(image_max, b_image_max)
|
||||
image_min = torch.maximum(image_min, b_image_min)
|
||||
state_max = torch.maximum(state_max, b_state_max)
|
||||
state_min = torch.maximum(state_min, b_state_min)
|
||||
action_max = torch.maximum(action_max, b_action_max)
|
||||
action_min = torch.maximum(action_min, b_action_min)
|
||||
|
||||
if i < num_batch - 1:
|
||||
batch = rb.sample()
|
||||
|
||||
image_std = torch.sqrt(image_std / num_batch)
|
||||
state_std = torch.sqrt(state_std / num_batch)
|
||||
action_std = torch.sqrt(action_std / num_batch)
|
||||
|
||||
stats = TensorDict(
|
||||
{
|
||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||
("observation", "image", "std"): image_std[None, :, None, None],
|
||||
("observation", "image", "max"): image_max[None, :, None, None],
|
||||
("observation", "image", "min"): image_min[None, :, None, None],
|
||||
("observation", "state", "mean"): state_mean[None, :],
|
||||
("observation", "state", "std"): state_std[None, :],
|
||||
("observation", "state", "max"): state_max[None, :],
|
||||
("observation", "state", "min"): state_min[None, :],
|
||||
("action", "mean"): action_mean[None, :],
|
||||
("action", "std"): action_std[None, :],
|
||||
("action", "max"): action_max[None, :],
|
||||
("action", "min"): action_min[None, :],
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||
return stats
|
||||
|
||||
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||
stats_path = self.root / self.dataset_id / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(storage)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
@@ -7,130 +7,71 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
SliceSampler,
|
||||
SliceSamplerWithoutReplacement,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
import gdown
|
||||
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmDataset(AbstractDataset):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id,
|
||||
batch_size: int = None,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
pad: float = None,
|
||||
replacement: bool = None,
|
||||
streaming: bool = False,
|
||||
root: Path = None,
|
||||
download: bool = False,
|
||||
sampler: Sampler = None,
|
||||
writer: Writer = None,
|
||||
collate_fn: Callable = None,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
transform: "torchrl.envs.Transform" = None, # noqa-F821
|
||||
split_trajs: bool = False,
|
||||
strict_length: bool = True,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.download = download
|
||||
if streaming:
|
||||
raise NotImplementedError
|
||||
self.streaming = streaming
|
||||
self.dataset_id = dataset_id
|
||||
self.split_trajs = split_trajs
|
||||
self.shuffle = shuffle
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.pad = pad
|
||||
|
||||
self.strict_length = strict_length
|
||||
if (self.num_slices is not None) and (self.slice_len is not None):
|
||||
raise ValueError("num_slices or slice_len can be not None, but not both.")
|
||||
if split_trajs:
|
||||
raise NotImplementedError
|
||||
|
||||
if root is None:
|
||||
root = _get_root_dir("simxarm")
|
||||
os.makedirs(root, exist_ok=True)
|
||||
self.root = Path(root)
|
||||
if self.download == "force" or (self.download and not self._is_downloaded()):
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||
|
||||
if num_slices is not None or slice_len is not None:
|
||||
if sampler is not None:
|
||||
raise ValueError("`num_slices` and `slice_len` are exclusive with the `sampler` argument.")
|
||||
|
||||
if replacement:
|
||||
if not self.shuffle:
|
||||
raise RuntimeError("shuffle=False can only be used when replacement=False.")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
)
|
||||
else:
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=num_slices,
|
||||
slice_len=slice_len,
|
||||
strict_length=strict_length,
|
||||
shuffle=self.shuffle,
|
||||
)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
collate_fn = _collate_id
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=writer,
|
||||
collate_fn=collate_fn,
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self):
|
||||
return len(self)
|
||||
def _download_and_preproc_obsolete(self):
|
||||
# assert self.root is not None
|
||||
# TODO(rcadene): finish download
|
||||
# download()
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
@property
|
||||
def data_path_root(self):
|
||||
if self.streaming:
|
||||
return None
|
||||
return self.root / self.dataset_id
|
||||
|
||||
def _is_downloaded(self):
|
||||
return os.path.exists(self.data_path_root)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# download
|
||||
# TODO(rcadene)
|
||||
|
||||
# load
|
||||
dataset_dir = Path("data") / self.dataset_id
|
||||
dataset_path = dataset_dir / "buffer.pkl"
|
||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
@@ -164,15 +105,19 @@ class SimxarmExperienceReplay(TensorDictReplayBuffer):
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "observation", "reward"): next_reward,
|
||||
("next", "observation", "done"): next_done,
|
||||
("next", "reward"): next_reward,
|
||||
("next", "done"): next_done,
|
||||
},
|
||||
batch_size=num_frames,
|
||||
)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.root / self.dataset_id)
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
|
||||
)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
|
||||
92
lerobot/common/envs/abstract.py
Normal file
92
lerobot/common/envs/abstract.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.utils import set_global_seed
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
"""
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the environment in factory.py
|
||||
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
|
||||
assert (
|
||||
self.available_tasks is not None
|
||||
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
|
||||
assert (
|
||||
task in self.available_tasks
|
||||
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
|
||||
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
self._make_env()
|
||||
self._make_spec()
|
||||
|
||||
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
|
||||
# you store the return value in self._next_seed (it will be a new randomly generated seed).
|
||||
self._next_seed = seed
|
||||
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
|
||||
# self._reset is called, we use seed.
|
||||
self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _make_env(self):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_global_seed(seed)
|
||||
@@ -0,0 +1,59 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,48 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,53 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,42 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
@@ -0,0 +1,38 @@
|
||||
<mujocoinclude>
|
||||
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||
|
||||
<asset>
|
||||
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||
</asset>
|
||||
|
||||
<visual>
|
||||
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||
<quality shadowsize="4096" offsamples="4"/>
|
||||
<headlight ambient="0.4 0.4 0.4"/>
|
||||
</visual>
|
||||
|
||||
<worldbody>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||
dir='1 1 -1'/>
|
||||
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||
dir='0 -1 -1'/>
|
||||
|
||||
<body name="table" pos="0 .6 0">
|
||||
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||
</body>
|
||||
<body name="midair" pos="0 .6 0.2">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||
</body>
|
||||
|
||||
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
|
||||
</worldbody>
|
||||
|
||||
|
||||
|
||||
</mujocoinclude>
|
||||
3
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
3
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:76a1571d1aa36520f2bd81c268991b99816c2a7819464d718e0fd9976fe30dce
|
||||
size 684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:df73ae5b9058e5d50a6409ac2ab687dade75053a86591bb5e23ab051dbf2d659
|
||||
size 83384
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:56fb3cc1236d4193106038adf8e457c7252ae9e86c7cee6dabf0578c53666358
|
||||
size 83384
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a4baacd9a64df1be60ea5e98f50f3c660e1b7a1fe9684aace6004c5058c09483
|
||||
size 42884
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a18a1601074d29ed1d546ead70cd18fbb063f1db7b5b96b9f0365be714f3136a
|
||||
size 3884
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d100cafe656671ca8fde98fb6a4cf2d1b746995c51c61c25ad9ea2715635d146
|
||||
size 99984
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:139745a74055cb0b23430bb5bc032bf68cf7bea5e4975c8f4c04107ae005f7f0
|
||||
size 63884
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:900f236320dd3d500870c5fde763b2d47502d51e043a5c377875e70237108729
|
||||
size 102984
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4104fc54bbfb8a9b533029f1e7e3ade3d54d638372b3195daa0c98f57e0295b5
|
||||
size 49584
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:66814e27fa728056416e25e02e89eb7d34c51d51c51e7c3df873829037ddc6b8
|
||||
size 99884
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:90eb145c85627968c3776ae6de23ccff7e112c9dd713c46bc9acdfdaa859a048
|
||||
size 70784
|
||||
3
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:786c1077bfd226f14219581b11d5f19464ca95b17132e0bb7532503568f5af90
|
||||
size 450084
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d1275a93fe2157c83dbc095617fb7e672888bdd48ec070a35ef4ab9ebd9755b0
|
||||
size 31684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a4de62c9a2ed2c78433010e4c05530a1254b1774a7651967f406120c9bf8973e
|
||||
size 379484
|
||||
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
@@ -0,0 +1,17 @@
|
||||
<mujocoinclude>
|
||||
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||
<asset>
|
||||
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||
</asset>
|
||||
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
163
lerobot/common/envs/aloha/constants.py
Normal file
163
lerobot/common/envs/aloha/constants.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from pathlib import Path
|
||||
|
||||
### Simulation envs fixed constants
|
||||
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
|
||||
FPS = 50
|
||||
|
||||
|
||||
JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
|
||||
START_ARM_POSE = [
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
]
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
|
||||
def normalize_master_gripper_position(x):
|
||||
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_position(x):
|
||||
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_position(x):
|
||||
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_position(x):
|
||||
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def convert_position_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
|
||||
|
||||
|
||||
def normalizer_master_gripper_joint(x):
|
||||
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_joint(x):
|
||||
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_joint(x):
|
||||
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_joint(x):
|
||||
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def convert_join_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
|
||||
|
||||
|
||||
def normalize_master_gripper_velocity(x):
|
||||
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_velocity(x):
|
||||
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def convert_master_from_position_to_joint(x):
|
||||
return (
|
||||
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_master_from_joint_to_position(x):
|
||||
return unnormalize_master_gripper_position(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_position_to_join(x):
|
||||
return (
|
||||
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_joint_to_position(x):
|
||||
return unnormalize_puppet_gripper_position(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
298
lerobot/common/envs/aloha/env.py
Normal file
298
lerobot/common/envs/aloha/env.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
ACTIONS,
|
||||
ASSETS_DIR,
|
||||
DT,
|
||||
JOINTS,
|
||||
)
|
||||
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
|
||||
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||
InsertionEndEffectorTask,
|
||||
TransferCubeEndEffectorTask,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
from lerobot.common.utils import set_global_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
name = "aloha"
|
||||
available_tasks = ["sim_insertion", "sim_transfer_cube"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError()
|
||||
|
||||
self._env = self._make_env_task(self.task)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||
return image
|
||||
|
||||
def _make_env_task(self, task_name):
|
||||
# time limit is controlled by StepCounter in env factory
|
||||
time_limit = float("inf")
|
||||
|
||||
if "sim_transfer_cube" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeTask(random=False)
|
||||
elif "sim_insertion" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionTask(random=False)
|
||||
elif "sim_end_effector_transfer_cube" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeEndEffectorTask(random=False)
|
||||
elif "sim_end_effector_insertion" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionEndEffectorTask(random=False)
|
||||
else:
|
||||
raise NotImplementedError(task_name)
|
||||
|
||||
env = control.Environment(
|
||||
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
|
||||
)
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["images"]["top"].copy())
|
||||
image = einops.rearrange(image, "h w c -> c h w")
|
||||
assert image.dtype == torch.uint8
|
||||
obs = {"image": {"top": image}}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO(rcadene):
|
||||
raise NotImplementedError()
|
||||
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
if tensordict is not None and not AlohaEnv._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
AlohaEnv._reset_warning_issued = True
|
||||
|
||||
# Seed the environment and update the seed to be used for the next reset.
|
||||
self._next_seed = self.set_seed(self._next_seed)
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
assert action.ndim == 1
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
|
||||
_, reward, _, raw_obs = self._env.step(action)
|
||||
|
||||
# TODO(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
# success and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if self.from_pixels:
|
||||
if isinstance(self.image_size, int):
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
|
||||
assert len(self.image_size) == 3 # c h w
|
||||
assert self.image_size[0] == 3 # c is RGB
|
||||
image_shape = tuple(self.image_size)
|
||||
else:
|
||||
raise ValueError(self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = {
|
||||
"top": BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
}
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
# TODO(rcadene): valid when controling end effector?
|
||||
# action_space = self._env.action_spec()
|
||||
# self.action_spec = BoundedTensorSpec(
|
||||
# low=action_space.minimum,
|
||||
# high=action_space.maximum,
|
||||
# shape=action_space.shape,
|
||||
# dtype=torch.float32,
|
||||
# device=self.device,
|
||||
# )
|
||||
|
||||
# TODO(rcaene): add bounds (where are they????)
|
||||
self.action_spec = BoundedTensorSpec(
|
||||
shape=(len(ACTIONS)),
|
||||
low=-1,
|
||||
high=1,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_global_seed(seed)
|
||||
# TODO(rcadene): seed the env
|
||||
# self._env.seed(seed)
|
||||
logging.warning("Aloha env is not seeded")
|
||||
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7 : 7 + 6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7 + 6]
|
||||
|
||||
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
|
||||
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate(
|
||||
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
|
||||
)
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXEndEffectorTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
|
||||
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array(
|
||||
[
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
]
|
||||
)
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs["mocap_pose_left"] = np.concatenate(
|
||||
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
|
||||
).copy()
|
||||
obs["mocap_pose_right"] = np.concatenate(
|
||||
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
|
||||
).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs["gripper_ctrl"] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id("red_box_joint", "joint")
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
|
||||
def id2index(j_id):
|
||||
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
39
lerobot/common/envs/aloha/utils.py
Normal file
39
lerobot/common/envs/aloha/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
@@ -1,61 +1,64 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
from torchrl.envs import SerialEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
"""
|
||||
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||
environments. The env therefore returns batches.`
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
# TODO(rcadene): do we want a specific eval_env_seed?
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht import PushtEnv
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
clsfunc = PushtEnv
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
env = clsfunc(**kwargs)
|
||||
def _make_env(seed):
|
||||
nonlocal kwargs
|
||||
kwargs["seed"] = seed
|
||||
env = clsfunc(**kwargs)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
return env
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
# env = GymEnv(
|
||||
# env_name,
|
||||
# frame_skip=frame_skip,
|
||||
# from_pixels=True,
|
||||
# pixels_only=False,
|
||||
# device=device,
|
||||
# )
|
||||
# env = TransformedEnv(env)
|
||||
# env.append_transform(NoopResetEnv(noops=30, random=True))
|
||||
# if not is_test:
|
||||
# env.append_transform(EndOfLifeTransform())
|
||||
# env.append_transform(RewardClipping(-1, 1))
|
||||
# env.append_transform(ToTensorImage())
|
||||
# env.append_transform(GrayScale())
|
||||
# env.append_transform(Resize(84, 84))
|
||||
# env.append_transform(CatFrames(N=4, dim=-3))
|
||||
# env.append_transform(RewardSum())
|
||||
# env.append_transform(StepCounter(max_steps=4500))
|
||||
# env.append_transform(DoubleToFloat())
|
||||
# env.append_transform(VecNorm(in_keys=["pixels"]))
|
||||
# return env
|
||||
return SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=_make_env,
|
||||
create_env_kwargs=[
|
||||
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
|
||||
245
lerobot/common/envs/pusht/env.py
Normal file
245
lerobot/common/envs/pusht/env.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_global_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class PushtEnv(AbstractEnv):
|
||||
name = "pusht"
|
||||
available_tasks = ["pusht"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task="pusht",
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
|
||||
"""
|
||||
with_marker adds a cursor showing the targeted action for the controller.
|
||||
"""
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
tmp = self._env.render_size
|
||||
if width != self._env.render_size:
|
||||
self._env.render_cache = None
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode).copy()
|
||||
if with_marker and self._env.latest_action is not None:
|
||||
action = np.array(self._env.latest_action)
|
||||
coord = (action / 512 * self._env.render_size).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self._env.render_size)
|
||||
thickness = int(1 / 96 * self._env.render_size)
|
||||
cv2.drawMarker(
|
||||
out,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
self._env.render_size = tmp
|
||||
return out
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["image"])
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO:
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
if tensordict is not None and not PushtEnv._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
PushtEnv._reset_warning_issued = True
|
||||
|
||||
# Seed the environment and update the seed to be used for the next reset.
|
||||
self._next_seed = self.set_seed(self._next_seed)
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
assert action.ndim == 1
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
# success and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([done], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = self._env.observation_space["agent_pos"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=512,
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO:
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._env.action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
# Set global seed.
|
||||
set_global_seed(seed)
|
||||
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
|
||||
self._env.seed(seed)
|
||||
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
@@ -0,0 +1,378 @@
|
||||
import collections
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import pymunk.pygame_util
|
||||
import shapely.geometry as sg
|
||||
import skimage.transform as st
|
||||
from gymnasium import spaces
|
||||
from pymunk.vec2d import Vec2d
|
||||
|
||||
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
|
||||
|
||||
|
||||
def pymunk_to_shapely(body, shapes):
|
||||
geoms = []
|
||||
for shape in shapes:
|
||||
if isinstance(shape, pymunk.shapes.Poly):
|
||||
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
||||
verts += [verts[0]]
|
||||
geoms.append(sg.Polygon(verts))
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
||||
geom = sg.MultiPolygon(geoms)
|
||||
return geom
|
||||
|
||||
|
||||
class PushTEnv(gym.Env):
|
||||
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
|
||||
reward_range = (0.0, 1.0)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
legacy=True, # compatibility with original
|
||||
block_cog=None,
|
||||
damping=None,
|
||||
render_action=True,
|
||||
render_size=96,
|
||||
reset_to_state=None,
|
||||
):
|
||||
self._seed = None
|
||||
self.seed()
|
||||
self.window_size = ws = 512 # The size of the PyGame window
|
||||
self.render_size = render_size
|
||||
self.sim_hz = 100
|
||||
# Local controller params.
|
||||
self.k_p, self.k_v = 100, 20 # PD control.z
|
||||
self.control_hz = self.metadata["video.frames_per_second"]
|
||||
# legcay set_state for data compatibility
|
||||
self.legacy = legacy
|
||||
|
||||
# agent_pos, block_pos, block_angle
|
||||
self.observation_space = spaces.Box(
|
||||
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
|
||||
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
|
||||
shape=(5,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
# positional goal for agent
|
||||
self.action_space = spaces.Box(
|
||||
low=np.array([0, 0], dtype=np.float64),
|
||||
high=np.array([ws, ws], dtype=np.float64),
|
||||
shape=(2,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
self.block_cog = block_cog
|
||||
self.damping = damping
|
||||
self.render_action = render_action
|
||||
|
||||
"""
|
||||
If human-rendering is used, `self.window` will be a reference
|
||||
to the window that we draw to. `self.clock` will be a clock that is used
|
||||
to ensure that the environment is rendered at the correct framerate in
|
||||
human-mode. They will remain `None` until human-mode is used for the
|
||||
first time.
|
||||
"""
|
||||
self.window = None
|
||||
self.clock = None
|
||||
self.screen = None
|
||||
|
||||
self.space = None
|
||||
self.teleop = None
|
||||
self.render_buffer = None
|
||||
self.latest_action = None
|
||||
self.reset_to_state = reset_to_state
|
||||
|
||||
def reset(self):
|
||||
seed = self._seed
|
||||
self._setup()
|
||||
if self.block_cog is not None:
|
||||
self.block.center_of_gravity = self.block_cog
|
||||
if self.damping is not None:
|
||||
self.space.damping = self.damping
|
||||
|
||||
# use legacy RandomState for compatibility
|
||||
state = self.reset_to_state
|
||||
if state is None:
|
||||
rs = np.random.RandomState(seed=seed)
|
||||
state = np.array(
|
||||
[
|
||||
rs.randint(50, 450),
|
||||
rs.randint(50, 450),
|
||||
rs.randint(100, 400),
|
||||
rs.randint(100, 400),
|
||||
rs.randn() * 2 * np.pi - np.pi,
|
||||
]
|
||||
)
|
||||
self._set_state(state)
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
dt = 1.0 / self.sim_hz
|
||||
self.n_contact_points = 0
|
||||
n_steps = self.sim_hz // self.control_hz
|
||||
if action is not None:
|
||||
self.latest_action = action
|
||||
for _ in range(n_steps):
|
||||
# Step PD control.
|
||||
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
|
||||
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
|
||||
Vec2d(0, 0) - self.agent.velocity
|
||||
)
|
||||
self.agent.velocity += acceleration * dt
|
||||
|
||||
# Step physics.
|
||||
self.space.step(dt)
|
||||
|
||||
# compute reward
|
||||
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
|
||||
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
|
||||
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward = np.clip(coverage / self.success_threshold, 0, 1)
|
||||
done = coverage > self.success_threshold
|
||||
|
||||
observation = self._get_obs()
|
||||
info = self._get_info()
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def render(self, mode):
|
||||
return self._render_frame(mode)
|
||||
|
||||
def teleop_agent(self):
|
||||
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
|
||||
|
||||
def act(obs):
|
||||
act = None
|
||||
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
|
||||
if self.teleop or (mouse_position - self.agent.position).length < 30:
|
||||
self.teleop = True
|
||||
act = mouse_position
|
||||
return act
|
||||
|
||||
return TeleopAgent(act)
|
||||
|
||||
def _get_obs(self):
|
||||
obs = np.array(
|
||||
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
|
||||
)
|
||||
return obs
|
||||
|
||||
def _get_goal_pose_body(self, pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
# preserving the legacy assignment order for compatibility
|
||||
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||
body.position = pose[:2].tolist()
|
||||
body.angle = pose[2]
|
||||
return body
|
||||
|
||||
def _get_info(self):
|
||||
n_steps = self.sim_hz // self.control_hz
|
||||
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
|
||||
info = {
|
||||
"pos_agent": np.array(self.agent.position),
|
||||
"vel_agent": np.array(self.agent.velocity),
|
||||
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
|
||||
"goal_pose": self.goal_pose,
|
||||
"n_contacts": n_contact_points_per_step,
|
||||
}
|
||||
return info
|
||||
|
||||
def _render_frame(self, mode):
|
||||
if self.window is None and mode == "human":
|
||||
pygame.init()
|
||||
pygame.display.init()
|
||||
self.window = pygame.display.set_mode((self.window_size, self.window_size))
|
||||
if self.clock is None and mode == "human":
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
canvas = pygame.Surface((self.window_size, self.window_size))
|
||||
canvas.fill((255, 255, 255))
|
||||
self.screen = canvas
|
||||
|
||||
draw_options = DrawOptions(canvas)
|
||||
|
||||
# Draw goal pose.
|
||||
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||
for shape in self.block.shapes:
|
||||
goal_points = [
|
||||
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
|
||||
for v in shape.get_vertices()
|
||||
]
|
||||
goal_points += [goal_points[0]]
|
||||
pygame.draw.polygon(canvas, self.goal_color, goal_points)
|
||||
|
||||
# Draw agent and block.
|
||||
self.space.debug_draw(draw_options)
|
||||
|
||||
if mode == "human":
|
||||
# The following line copies our drawings from `canvas` to the visible window
|
||||
self.window.blit(canvas, canvas.get_rect())
|
||||
pygame.event.pump()
|
||||
pygame.display.update()
|
||||
|
||||
# the clock is already ticked during in step for "human"
|
||||
|
||||
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
|
||||
img = cv2.resize(img, (self.render_size, self.render_size))
|
||||
if self.render_action and self.latest_action is not None:
|
||||
action = np.array(self.latest_action)
|
||||
coord = (action / 512 * 96).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self.render_size)
|
||||
thickness = int(1 / 96 * self.render_size)
|
||||
cv2.drawMarker(
|
||||
img,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
return img
|
||||
|
||||
def close(self):
|
||||
if self.window is not None:
|
||||
pygame.display.quit()
|
||||
pygame.quit()
|
||||
|
||||
def seed(self, seed=None):
|
||||
if seed is None:
|
||||
seed = np.random.randint(0, 25536)
|
||||
self._seed = seed
|
||||
self.np_random = np.random.default_rng(seed)
|
||||
|
||||
def _handle_collision(self, arbiter, space, data):
|
||||
self.n_contact_points += len(arbiter.contact_point_set.points)
|
||||
|
||||
def _set_state(self, state):
|
||||
if isinstance(state, np.ndarray):
|
||||
state = state.tolist()
|
||||
pos_agent = state[:2]
|
||||
pos_block = state[2:4]
|
||||
rot_block = state[4]
|
||||
self.agent.position = pos_agent
|
||||
# setting angle rotates with respect to center of mass
|
||||
# therefore will modify the geometric position
|
||||
# if not the same as CoM
|
||||
# therefore should be modified first.
|
||||
if self.legacy:
|
||||
# for compatibility with legacy data
|
||||
self.block.position = pos_block
|
||||
self.block.angle = rot_block
|
||||
else:
|
||||
self.block.angle = rot_block
|
||||
self.block.position = pos_block
|
||||
|
||||
# Run physics to take effect
|
||||
self.space.step(1.0 / self.sim_hz)
|
||||
|
||||
def _set_state_local(self, state_local):
|
||||
agent_pos_local = state_local[:2]
|
||||
block_pose_local = state_local[2:]
|
||||
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
|
||||
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
|
||||
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
|
||||
agent_pos_new = tf_img_new(agent_pos_local)
|
||||
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
|
||||
self._set_state(new_state)
|
||||
return new_state
|
||||
|
||||
def _setup(self):
|
||||
self.space = pymunk.Space()
|
||||
self.space.gravity = 0, 0
|
||||
self.space.damping = 0
|
||||
self.teleop = False
|
||||
self.render_buffer = []
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
self._add_segment((5, 506), (5, 5), 2),
|
||||
self._add_segment((5, 5), (506, 5), 2),
|
||||
self._add_segment((506, 5), (506, 506), 2),
|
||||
self._add_segment((5, 506), (506, 506), 2),
|
||||
]
|
||||
self.space.add(*walls)
|
||||
|
||||
# Add agent, block, and goal zone.
|
||||
self.agent = self.add_circle((256, 400), 15)
|
||||
self.block = self.add_tee((256, 300), 0)
|
||||
self.goal_color = pygame.Color("LightGreen")
|
||||
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
|
||||
# Add collision handling
|
||||
self.collision_handeler = self.space.add_collision_handler(0, 0)
|
||||
self.collision_handeler.post_solve = self._handle_collision
|
||||
self.n_contact_points = 0
|
||||
|
||||
self.max_score = 50 * 100
|
||||
self.success_threshold = 0.95 # 95% coverage.
|
||||
|
||||
def _add_segment(self, a, b, radius):
|
||||
shape = pymunk.Segment(self.space.static_body, a, b, radius)
|
||||
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||
return shape
|
||||
|
||||
def add_circle(self, position, radius):
|
||||
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
|
||||
body.position = position
|
||||
body.friction = 1
|
||||
shape = pymunk.Circle(body, radius)
|
||||
shape.color = pygame.Color("RoyalBlue")
|
||||
self.space.add(body, shape)
|
||||
return body
|
||||
|
||||
def add_box(self, position, height, width):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (height, width))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
body.position = position
|
||||
shape = pymunk.Poly.create_box(body, (height, width))
|
||||
shape.color = pygame.Color("LightSlateGray")
|
||||
self.space.add(body, shape)
|
||||
return body
|
||||
|
||||
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
|
||||
if mask is None:
|
||||
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
(-length * scale / 2, scale),
|
||||
(length * scale / 2, scale),
|
||||
(length * scale / 2, 0),
|
||||
(-length * scale / 2, 0),
|
||||
]
|
||||
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
vertices2 = [
|
||||
(-scale / 2, scale),
|
||||
(-scale / 2, length * scale),
|
||||
(scale / 2, length * scale),
|
||||
(scale / 2, scale),
|
||||
]
|
||||
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||
shape1 = pymunk.Poly(body, vertices1)
|
||||
shape2 = pymunk.Poly(body, vertices2)
|
||||
shape1.color = pygame.Color(color)
|
||||
shape2.color = pygame.Color(color)
|
||||
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||
body.position = position
|
||||
body.angle = angle
|
||||
body.friction = 1
|
||||
self.space.add(body, shape1, shape2)
|
||||
return body
|
||||
41
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
41
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
|
||||
class PushTImageEnv(PushTEnv):
|
||||
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
|
||||
|
||||
# Note: legacy defaults to True for compatibility with original
|
||||
def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
|
||||
super().__init__(
|
||||
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
|
||||
)
|
||||
ws = self.window_size
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
|
||||
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
|
||||
}
|
||||
)
|
||||
self.render_cache = None
|
||||
|
||||
def _get_obs(self):
|
||||
img = super()._render_frame(mode="rgb_array")
|
||||
|
||||
agent_pos = np.array(self.agent.position)
|
||||
img_obs = np.moveaxis(img, -1, 0)
|
||||
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||
|
||||
self.render_cache = img
|
||||
|
||||
return obs
|
||||
|
||||
def render(self, mode):
|
||||
assert mode == "rgb_array"
|
||||
|
||||
if self.render_cache is None:
|
||||
self._get_obs()
|
||||
|
||||
return self.render_cache
|
||||
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# ----------------------------------------------------------------------------
|
||||
# pymunk
|
||||
# Copyright (c) 2007-2016 Victor Blomqvist
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in
|
||||
# all copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
"""This submodule contains helper functions to help with quick prototyping
|
||||
using pymunk together with pygame.
|
||||
|
||||
Intended to help with debugging and prototyping, not for actual production use
|
||||
in a full application. The methods contained in this module is opinionated
|
||||
about your coordinate system and not in any way optimized.
|
||||
"""
|
||||
|
||||
__docformat__ = "reStructuredText"
|
||||
|
||||
__all__ = [
|
||||
"DrawOptions",
|
||||
"get_mouse_pos",
|
||||
"to_pygame",
|
||||
"from_pygame",
|
||||
# "lighten",
|
||||
"positive_y_is_up",
|
||||
]
|
||||
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
from pymunk.space_debug_draw_options import SpaceDebugColor
|
||||
from pymunk.vec2d import Vec2d
|
||||
|
||||
positive_y_is_up: bool = False
|
||||
"""Make increasing values of y point upwards.
|
||||
|
||||
When True::
|
||||
|
||||
y
|
||||
^
|
||||
| . (3, 3)
|
||||
|
|
||||
| . (2, 2)
|
||||
|
|
||||
+------ > x
|
||||
|
||||
When False::
|
||||
|
||||
+------ > x
|
||||
|
|
||||
| . (2, 2)
|
||||
|
|
||||
| . (3, 3)
|
||||
v
|
||||
y
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class DrawOptions(pymunk.SpaceDebugDrawOptions):
|
||||
def __init__(self, surface: pygame.Surface) -> None:
|
||||
"""Draw a pymunk.Space on a pygame.Surface object.
|
||||
|
||||
Typical usage::
|
||||
|
||||
>>> import pymunk
|
||||
>>> surface = pygame.Surface((10,10))
|
||||
>>> space = pymunk.Space()
|
||||
>>> options = pymunk.pygame_util.DrawOptions(surface)
|
||||
>>> space.debug_draw(options)
|
||||
|
||||
You can control the color of a shape by setting shape.color to the color
|
||||
you want it drawn in::
|
||||
|
||||
>>> c = pymunk.Circle(None, 10)
|
||||
>>> c.color = pygame.Color("pink")
|
||||
|
||||
See pygame_util.demo.py for a full example
|
||||
|
||||
Since pygame uses a coordinate system where y points down (in contrast
|
||||
to many other cases), you either have to make the physics simulation
|
||||
with Pymunk also behave in that way, or flip everything when you draw.
|
||||
|
||||
The easiest is probably to just make the simulation behave the same
|
||||
way as Pygame does. In that way all coordinates used are in the same
|
||||
orientation and easy to reason about::
|
||||
|
||||
>>> space = pymunk.Space()
|
||||
>>> space.gravity = (0, -1000)
|
||||
>>> body = pymunk.Body()
|
||||
>>> body.position = (0, 0) # will be positioned in the top left corner
|
||||
>>> space.debug_draw(options)
|
||||
|
||||
To flip the drawing its possible to set the module property
|
||||
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
|
||||
the simulation upside down before drawing::
|
||||
|
||||
>>> positive_y_is_up = True
|
||||
>>> body = pymunk.Body()
|
||||
>>> body.position = (0, 0)
|
||||
>>> # Body will be position in bottom left corner
|
||||
|
||||
:Parameters:
|
||||
surface : pygame.Surface
|
||||
Surface that the objects will be drawn on
|
||||
"""
|
||||
self.surface = surface
|
||||
super().__init__()
|
||||
|
||||
def draw_circle(
|
||||
self,
|
||||
pos: Vec2d,
|
||||
angle: float,
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
p = to_pygame(pos, self.surface)
|
||||
|
||||
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
|
||||
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
|
||||
|
||||
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
|
||||
# p2 = to_pygame(circle_edge, self.surface)
|
||||
# line_r = 2 if radius > 20 else 1
|
||||
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
|
||||
|
||||
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
|
||||
p1 = to_pygame(a, self.surface)
|
||||
p2 = to_pygame(b, self.surface)
|
||||
|
||||
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
|
||||
|
||||
def draw_fat_segment(
|
||||
self,
|
||||
a: Tuple[float, float],
|
||||
b: Tuple[float, float],
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
p1 = to_pygame(a, self.surface)
|
||||
p2 = to_pygame(b, self.surface)
|
||||
|
||||
r = round(max(1, radius * 2))
|
||||
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
|
||||
if r > 2:
|
||||
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
|
||||
if orthog[0] == 0 and orthog[1] == 0:
|
||||
return
|
||||
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
|
||||
orthog[0] = round(orthog[0] * scale)
|
||||
orthog[1] = round(orthog[1] * scale)
|
||||
points = [
|
||||
(p1[0] - orthog[0], p1[1] - orthog[1]),
|
||||
(p1[0] + orthog[0], p1[1] + orthog[1]),
|
||||
(p2[0] + orthog[0], p2[1] + orthog[1]),
|
||||
(p2[0] - orthog[0], p2[1] - orthog[1]),
|
||||
]
|
||||
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
|
||||
pygame.draw.circle(
|
||||
self.surface,
|
||||
fill_color.as_int(),
|
||||
(round(p1[0]), round(p1[1])),
|
||||
round(radius),
|
||||
)
|
||||
pygame.draw.circle(
|
||||
self.surface,
|
||||
fill_color.as_int(),
|
||||
(round(p2[0]), round(p2[1])),
|
||||
round(radius),
|
||||
)
|
||||
|
||||
def draw_polygon(
|
||||
self,
|
||||
verts: Sequence[Tuple[float, float]],
|
||||
radius: float,
|
||||
outline_color: SpaceDebugColor,
|
||||
fill_color: SpaceDebugColor,
|
||||
) -> None:
|
||||
ps = [to_pygame(v, self.surface) for v in verts]
|
||||
ps += [ps[0]]
|
||||
|
||||
radius = 2
|
||||
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
|
||||
|
||||
if radius > 0:
|
||||
for i in range(len(verts)):
|
||||
a = verts[i]
|
||||
b = verts[(i + 1) % len(verts)]
|
||||
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
|
||||
|
||||
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
|
||||
p = to_pygame(pos, self.surface)
|
||||
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
|
||||
|
||||
|
||||
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Get position of the mouse pointer in pymunk coordinates."""
|
||||
p = pygame.mouse.get_pos()
|
||||
return from_pygame(p, surface)
|
||||
|
||||
|
||||
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Convenience method to convert pymunk coordinates to pygame surface
|
||||
local coordinates.
|
||||
|
||||
Note that in case positive_y_is_up is False, this function won't actually do
|
||||
anything except converting the point to integers.
|
||||
"""
|
||||
if positive_y_is_up:
|
||||
return round(p[0]), surface.get_height() - round(p[1])
|
||||
else:
|
||||
return round(p[0]), round(p[1])
|
||||
|
||||
|
||||
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||
"""Convenience method to convert pygame surface local coordinates to
|
||||
pymunk coordinates
|
||||
"""
|
||||
return to_pygame(p, surface)
|
||||
|
||||
|
||||
def light_color(color: SpaceDebugColor):
|
||||
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
|
||||
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
|
||||
return color
|
||||
@@ -1,181 +0,0 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_simxarm:
|
||||
raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
import gym
|
||||
from simxarm import TASKS
|
||||
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||
else:
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
self._env.seed(seed)
|
||||
@@ -1,7 +1,10 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -10,93 +13,86 @@ from torchrl.data.tensor_specs import (
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.utils import set_global_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
|
||||
MAX_NUM_ACTIONS = 4
|
||||
|
||||
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
name = "simxarm"
|
||||
available_tasks = ["lift"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_diffpolicy:
|
||||
raise ImportError("Cannot import diffusion_policy.")
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
import gymnasium
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from lerobot.common.envs.simxarm.simxarm import TASKS
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
tmp = self._env.render_size
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode)
|
||||
self._env.render_size = tmp
|
||||
return out
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["image"])
|
||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||
else:
|
||||
# TODO:
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
# obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
@@ -123,17 +119,19 @@ class PushtEnv(EnvBase):
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = action.repeat(self.frame_skip, 1)
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
@@ -157,11 +155,10 @@ class PushtEnv(EnvBase):
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
@@ -172,24 +169,22 @@ class PushtEnv(EnvBase):
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs, *image_shape)
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
high=255,
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = self._env.observation_space["agent_pos"].shape
|
||||
state_shape = (len(self._env.robot_state),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=512,
|
||||
shape=self._env.observation_space["agent_pos"].shape,
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -197,18 +192,18 @@ class PushtEnv(EnvBase):
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs, *state_shape)
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO:
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._env.action_space,
|
||||
self._action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -236,5 +231,7 @@ class PushtEnv(EnvBase):
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
self._env.seed(seed)
|
||||
set_global_seed(seed)
|
||||
self._seed = seed
|
||||
# TODO(aliberts): change self._reset so that it takes in a seed value
|
||||
logging.warning("simxarm env is not properly seeded")
|
||||
166
lerobot/common/envs/simxarm/simxarm/__init__.py
Normal file
166
lerobot/common/envs/simxarm/simxarm/__init__.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from collections import OrderedDict, deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium.wrappers import TimeLimit
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
|
||||
|
||||
TASKS = OrderedDict(
|
||||
(
|
||||
(
|
||||
"reach",
|
||||
{
|
||||
"env": Reach,
|
||||
"action_space": "xyz",
|
||||
"episode_length": 50,
|
||||
"description": "Reach a target location with the end effector",
|
||||
},
|
||||
),
|
||||
(
|
||||
"push",
|
||||
{
|
||||
"env": Push,
|
||||
"action_space": "xyz",
|
||||
"episode_length": 50,
|
||||
"description": "Push a cube to a target location",
|
||||
},
|
||||
),
|
||||
(
|
||||
"peg_in_box",
|
||||
{
|
||||
"env": PegInBox,
|
||||
"action_space": "xyz",
|
||||
"episode_length": 50,
|
||||
"description": "Insert a peg into a box",
|
||||
},
|
||||
),
|
||||
(
|
||||
"lift",
|
||||
{
|
||||
"env": Lift,
|
||||
"action_space": "xyzw",
|
||||
"episode_length": 50,
|
||||
"description": "Lift a cube above a height threshold",
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SimXarmWrapper(gym.Wrapper):
|
||||
"""
|
||||
A wrapper for the SimXarm environments. This wrapper is used to
|
||||
convert the action and observation spaces to the correct format.
|
||||
"""
|
||||
|
||||
def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
|
||||
super().__init__(env)
|
||||
self._env = env
|
||||
self.obs_mode = obs_mode
|
||||
self.image_size = image_size
|
||||
self.action_repeat = action_repeat
|
||||
self.frame_stack = frame_stack
|
||||
self._frames = deque([], maxlen=frame_stack)
|
||||
self.channel_last = channel_last
|
||||
self._max_episode_steps = task["episode_length"] // action_repeat
|
||||
|
||||
image_shape = (
|
||||
(image_size, image_size, 3 * frame_stack)
|
||||
if channel_last
|
||||
else (3 * frame_stack, image_size, image_size)
|
||||
)
|
||||
if obs_mode == "state":
|
||||
self.observation_space = env.observation_space["observation"]
|
||||
elif obs_mode == "rgb":
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
|
||||
elif obs_mode == "all":
|
||||
self.observation_space = gym.spaces.Dict(
|
||||
state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
|
||||
rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
|
||||
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
|
||||
self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
|
||||
if "w" not in task["action_space"]:
|
||||
self.action_padding[-1] = 1.0
|
||||
|
||||
def _render_obs(self):
|
||||
obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
if not self.channel_last:
|
||||
obs = obs.transpose(2, 0, 1)
|
||||
return obs.copy()
|
||||
|
||||
def _update_frames(self, reset=False):
|
||||
pixels = self._render_obs()
|
||||
self._frames.append(pixels)
|
||||
if reset:
|
||||
for _ in range(1, self.frame_stack):
|
||||
self._frames.append(pixels)
|
||||
assert len(self._frames) == self.frame_stack
|
||||
|
||||
def transform_obs(self, obs, reset=False):
|
||||
if self.obs_mode == "state":
|
||||
return obs["observation"]
|
||||
elif self.obs_mode == "rgb":
|
||||
self._update_frames(reset=reset)
|
||||
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
|
||||
return rgb_obs
|
||||
elif self.obs_mode == "all":
|
||||
self._update_frames(reset=reset)
|
||||
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
|
||||
return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
|
||||
else:
|
||||
raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")
|
||||
|
||||
def reset(self):
|
||||
return self.transform_obs(self._env.reset(), reset=True)
|
||||
|
||||
def step(self, action):
|
||||
action = np.concatenate([action, self.action_padding])
|
||||
reward = 0.0
|
||||
for _ in range(self.action_repeat):
|
||||
obs, r, done, info = self._env.step(action)
|
||||
reward += r
|
||||
return self.transform_obs(obs), reward, done, info
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384, **kwargs):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._env.robot_state
|
||||
|
||||
|
||||
def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
|
||||
"""
|
||||
Create a new environment.
|
||||
Args:
|
||||
task (str): The task to create an environment for. Must be one of:
|
||||
- 'reach'
|
||||
- 'push'
|
||||
- 'peg-in-box'
|
||||
- 'lift'
|
||||
obs_mode (str): The observation mode to use. Must be one of:
|
||||
- 'state': Only state observations
|
||||
- 'rgb': RGB images
|
||||
- 'all': RGB images and state observations
|
||||
image_size (int): The size of the image observations
|
||||
action_repeat (int): The number of times to repeat the action
|
||||
seed (int): The random seed to use
|
||||
Returns:
|
||||
gym.Env: The environment
|
||||
"""
|
||||
if task not in TASKS:
|
||||
raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
|
||||
env = TASKS[task]["env"]()
|
||||
env = TimeLimit(env, TASKS[task]["episode_length"])
|
||||
env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
|
||||
env.seed(seed)
|
||||
|
||||
return env
|
||||
53
lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
Normal file
53
lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
Normal file
@@ -0,0 +1,53 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<mujoco>
|
||||
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||
<size nconmax="2000" njmax="500"/>
|
||||
|
||||
<option timestep="0.002">
|
||||
<flag warmstart="enable"></flag>
|
||||
</option>
|
||||
|
||||
<include file="shared.xml"></include>
|
||||
|
||||
<worldbody>
|
||||
<body name="floor0" pos="0 0 0">
|
||||
<geom name="floorgeom0" pos="1.2 -2.0 0" size="20.0 20.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||
</body>
|
||||
|
||||
<include file="xarm.xml"></include>
|
||||
|
||||
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||
</body>
|
||||
|
||||
<body pos="1.5 0.075 0.3425" name="table0">
|
||||
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 1 1"></geom>
|
||||
</body>
|
||||
|
||||
<body name="object" pos="1.405 0.3 0.58625">
|
||||
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||
<geom size="0.035 0.035 0.035" type="box" name="object0" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
|
||||
<site name="object_site" pos="0 0 0" size="0.035 0.035 0.035" rgba="1 0 0 0" type="box"></site>
|
||||
</body>
|
||||
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||
|
||||
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||
</worldbody>
|
||||
|
||||
<equality>
|
||||
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||
</equality>
|
||||
|
||||
<actuator>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736
|
||||
size 1211434
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66
|
||||
size 3284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de
|
||||
size 3284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86
|
||||
size 6284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4
|
||||
size 242684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac
|
||||
size 366284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921
|
||||
size 22284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636
|
||||
size 184184
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455
|
||||
size 225584
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42
|
||||
size 237084
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089
|
||||
size 243684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90
|
||||
size 229084
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb
|
||||
size 399384
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18
|
||||
size 231684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0
|
||||
size 2106184
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d
|
||||
size 242684
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e
|
||||
size 366284
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8
|
||||
size 22284
|
||||
@@ -0,0 +1,74 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<mujoco>
|
||||
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||
<size nconmax="2000" njmax="500"/>
|
||||
|
||||
<option timestep="0.001">
|
||||
<flag warmstart="enable"></flag>
|
||||
</option>
|
||||
|
||||
<include file="shared.xml"></include>
|
||||
|
||||
<worldbody>
|
||||
<body name="floor0" pos="0 0 0">
|
||||
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||
</body>
|
||||
|
||||
<include file="xarm.xml"></include>
|
||||
|
||||
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||
</body>
|
||||
|
||||
<body pos="1.5 0.075 0.3425" name="table0">
|
||||
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||
</body>
|
||||
|
||||
<body name="box0" pos="1.605 0.25 0.55">
|
||||
<joint name="box_joint0" type="free" limited="false"></joint>
|
||||
<site name="box_site" pos="0 0.075 -0.01" size="0.02" rgba="0 0 0 0" type="sphere"></site>
|
||||
<geom name="box_side0" pos="0 0 0" size="0.065 0.002 0.04" type= "box" rgba="0.8 0.1 0.1 1" mass ="1" condim="4" />
|
||||
<geom name="box_side1" pos="0 0.149 0" size="0.065 0.002 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
|
||||
<geom name="box_side2" pos="0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.8 0.1 0.1 1" mass ="2" condim="4" />
|
||||
<geom name="box_side3" pos="-0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
|
||||
<geom name="box_side4" pos="-0 0.074 -0.038" size="0.065 0.075 0.002" type="box" rgba="0.5 0 0 1" mass ="2" condim="4"/>
|
||||
</body>
|
||||
|
||||
<body name="object0" pos="1.4 0.25 0.65">
|
||||
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||
<geom name="object_target0" type="cylinder" pos="0 0 -0.05" size="0.03 0.035" rgba="0.6 0.8 0.5 1" mass ="0.1" condim="3" />
|
||||
<site name="object_site" pos="0 0 -0.05" size="0.0325 0.0375" rgba="0 0 0 0" type="cylinder"></site>
|
||||
<body name="B0" pos="0 0 0" euler="0 0 0 ">
|
||||
<joint name="B0:joint" type="slide" limited="true" axis="0 0 1" damping="0.05" range="0.0001 0.0001001" solimpfriction="0.98 0.98 0.95" frictionloss="1"></joint>
|
||||
<geom type="capsule" size="0.002 0.03" rgba="0 0 0 1" mass="0.001" condim="4"/>
|
||||
<body name="B1" pos="0 0 0.04" euler="0 3.14 0 ">
|
||||
<joint name="B1:joint1" type="hinge" axis="1 0 0" range="-0.1 0.1" frictionloss="1"></joint>
|
||||
<joint name="B1:joint2" type="hinge" axis="0 1 0" range="-0.1 0.1" frictionloss="1"></joint>
|
||||
<joint name="B1:joint3" type="hinge" axis="0 0 1" range="-0.1 0.1" frictionloss="1"></joint>
|
||||
<geom type="capsule" size="0.002 0.004" rgba="1 0 0 0" mass="0.001" condim="4"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||
|
||||
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||
</worldbody>
|
||||
|
||||
<equality>
|
||||
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<weld body1="right_hand" body2="B1" solimp="0.99 0.99 0.99" solref="0.02 1"></weld>
|
||||
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||
</equality>
|
||||
|
||||
<actuator>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
54
lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
Normal file
54
lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
Normal file
@@ -0,0 +1,54 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<mujoco>
|
||||
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||
<size nconmax="2000" njmax="500"/>
|
||||
|
||||
<option timestep="0.002">
|
||||
<flag warmstart="enable"></flag>
|
||||
</option>
|
||||
|
||||
<include file="shared.xml"></include>
|
||||
|
||||
<worldbody>
|
||||
<body name="floor0" pos="0 0 0">
|
||||
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||
<site name="target0" pos="1.565 0.3 0.545" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
|
||||
</body>
|
||||
|
||||
<include file="xarm.xml"></include>
|
||||
|
||||
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||
</body>
|
||||
|
||||
<body pos="1.5 0.075 0.3425" name="table0">
|
||||
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||
</body>
|
||||
|
||||
<body name="object" pos="1.655 0.3 0.68">
|
||||
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||
<geom size="0.024 0.024 0.024" type="box" name="object" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
|
||||
<site name="object_site" pos="0 0 0" size="0.024 0.024 0.024" rgba="0 0 0 0" type="box"></site>
|
||||
</body>
|
||||
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||
|
||||
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||
</worldbody>
|
||||
|
||||
<equality>
|
||||
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||
</equality>
|
||||
|
||||
<actuator>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
48
lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
Normal file
48
lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
Normal file
@@ -0,0 +1,48 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
|
||||
<mujoco>
|
||||
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||
<size nconmax="2000" njmax="500"/>
|
||||
|
||||
<option timestep="0.002">
|
||||
<flag warmstart="enable"></flag>
|
||||
</option>
|
||||
|
||||
<include file="shared.xml"></include>
|
||||
|
||||
<worldbody>
|
||||
<body name="floor0" pos="0 0 0">
|
||||
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||
<site name="target0" pos="1.605 0.3 0.58" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
|
||||
</body>
|
||||
|
||||
<include file="xarm.xml"></include>
|
||||
|
||||
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||
</body>
|
||||
|
||||
<body pos="1.5 0.075 0.3425" name="table0">
|
||||
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||
</body>
|
||||
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||
|
||||
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||
</worldbody>
|
||||
|
||||
<equality>
|
||||
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||
</equality>
|
||||
|
||||
<actuator>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||
</actuator>
|
||||
</mujoco>
|
||||
51
lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
Normal file
51
lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
Normal file
@@ -0,0 +1,51 @@
|
||||
<mujoco>
|
||||
<asset>
|
||||
<texture type="skybox" builtin="gradient" rgb1="0.0 0.0 0.0" rgb2="0.0 0.0 0.0" width="32" height="32"></texture>
|
||||
<material name="floor_mat" specular="0" shininess="0.0" reflectance="0" rgba="0.043 0.055 0.051 1"></material>
|
||||
|
||||
<material name="table_mat" specular="0.2" shininess="0.2" reflectance="0" rgba="1 1 1 1"></material>
|
||||
<material name="pedestal_mat" specular="0.35" shininess="0.5" reflectance="0" rgba="0.705 0.585 0.405 1"></material>
|
||||
<material name="block_mat" specular="0.5" shininess="0.9" reflectance="0.05" rgba="0.373 0.678 0.627 1"></material>
|
||||
|
||||
<material name="robot0:geomMat" shininess="0.03" specular="0.4"></material>
|
||||
<material name="robot0:gripper_finger_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="robot0:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="background:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="robot0:arm_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="robot0:head_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="robot0:torso_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
<material name="robot0:base_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||
|
||||
<mesh name="link_base" file="link_base.stl" />
|
||||
<mesh name="link1" file="link1.stl" />
|
||||
<mesh name="link2" file="link2.stl" />
|
||||
<mesh name="link3" file="link3.stl" />
|
||||
<mesh name="link4" file="link4.stl" />
|
||||
<mesh name="link5" file="link5.stl" />
|
||||
<mesh name="link6" file="link6.stl" />
|
||||
<mesh name="link7" file="link7.stl" />
|
||||
<mesh name="base_link" file="base_link.stl" />
|
||||
<mesh name="left_outer_knuckle" file="left_outer_knuckle.stl" />
|
||||
<mesh name="left_finger" file="left_finger.stl" />
|
||||
<mesh name="left_inner_knuckle" file="left_inner_knuckle.stl" />
|
||||
<mesh name="right_outer_knuckle" file="right_outer_knuckle.stl" />
|
||||
<mesh name="right_finger" file="right_finger.stl" />
|
||||
<mesh name="right_inner_knuckle" file="right_inner_knuckle.stl" />
|
||||
</asset>
|
||||
|
||||
<equality>
|
||||
<weld body1="robot0:mocap2" body2="link7" solimp="0.9 0.95 0.001" solref="0.02 1"></weld>
|
||||
</equality>
|
||||
|
||||
<default>
|
||||
<joint armature="1" damping="0.1" limited="true"/>
|
||||
<default class="robot0:blue">
|
||||
<geom rgba="0.086 0.506 0.767 1.0"></geom>
|
||||
</default>
|
||||
|
||||
<default class="robot0:grey">
|
||||
<geom rgba="0.356 0.361 0.376 1.0"></geom>
|
||||
</default>
|
||||
</default>
|
||||
|
||||
</mujoco>
|
||||
88
lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
Normal file
88
lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
Normal file
@@ -0,0 +1,88 @@
|
||||
<mujoco model="xarm7">
|
||||
<body mocap="true" name="robot0:mocap2" pos="0 0 0">
|
||||
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0.5 0 0" size="0.005 0.005 0.005" type="box"></geom>
|
||||
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0 0 0" size="1 0.005 0.005" type="box"></geom>
|
||||
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0 0.5 0" size="0.005 1 0.001" type="box"></geom>
|
||||
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0.5 0 0" size="0.005 0.005 1" type="box"></geom>
|
||||
</body>
|
||||
|
||||
<body name="link0" pos="1.09 0.28 0.655">
|
||||
<geom name="bb" type="mesh" mesh="link_base" material="robot0:base_mat" rgba="1 1 1 1"/>
|
||||
<body name="link1" pos="0 0 0.267">
|
||||
<inertial pos="-0.0042142 0.02821 -0.0087788" quat="0.917781 -0.277115 0.0606681 0.277858" mass="0.42603" diaginertia="0.00144551 0.00137757 0.000823511" />
|
||||
<joint name="joint1" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="10" frictionloss="1" />
|
||||
<geom name="j1" type="mesh" mesh="link1" material="robot0:arm_mat" rgba="1 1 1 1"/>
|
||||
<body name="link2" pos="0 0 0" quat="0.707105 -0.707108 0 0">
|
||||
<inertial pos="-3.3178e-05 -0.12849 0.026337" quat="0.447793 0.894132 -0.00224061 0.00218314" mass="0.56095" diaginertia="0.00319151 0.00311598 0.000980804" />
|
||||
<joint name="joint2" pos="0 0 0" axis="0 0 1" limited="true" range="-2.059 2.0944" damping="10" frictionloss="1" />
|
||||
<geom name="j2" type="mesh" mesh="link2" material="robot0:head_mat" rgba="1 1 1 1"/>
|
||||
<body name="link3" pos="0 -0.293 0" quat="0.707105 0.707108 0 0">
|
||||
<inertial pos="0.04223 -0.023258 -0.0096674" quat="0.883205 0.339803 0.323238 0.000542237" mass="0.44463" diaginertia="0.00133227 0.00119126 0.000780475" />
|
||||
<joint name="joint3" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
|
||||
<geom name="j3" type="mesh" mesh="link3" material="robot0:gripper_mat" rgba="1 1 1 1"/>
|
||||
<body name="link4" pos="0.0525 0 0" quat="0.707105 0.707108 0 0">
|
||||
<inertial pos="0.067148 -0.10732 0.024479" quat="0.0654142 0.483317 -0.738663 0.465298" mass="0.52387" diaginertia="0.00288984 0.00282705 0.000894409" />
|
||||
<joint name="joint4" pos="0 0 0" axis="0 0 1" limited="true" range="-0.19198 3.927" damping="5" frictionloss="1" />
|
||||
<geom name="j4" type="mesh" mesh="link4" material="robot0:arm_mat" rgba="1 1 1 1"/>
|
||||
<body name="link5" pos="0.0775 -0.3425 0" quat="0.707105 0.707108 0 0">
|
||||
<inertial pos="-0.00023397 0.036705 -0.080064" quat="0.981064 -0.19003 0.00637998 0.0369004" mass="0.18554" diaginertia="0.00099553 0.000988613 0.000247126" />
|
||||
<joint name="joint5" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
|
||||
<geom name="j5" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link5" />
|
||||
<body name="link6" pos="0 0 0" quat="0.707105 0.707108 0 0">
|
||||
<inertial pos="0.058911 0.028469 0.0068428" quat="-0.188705 0.793535 0.166088 0.554173" mass="0.31344" diaginertia="0.000827892 0.000768871 0.000386708" />
|
||||
<joint name="joint6" pos="0 0 0" axis="0 0 1" limited="true" range="-1.69297 3.14159" damping="2" frictionloss="1" />
|
||||
<geom name="j6" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link6" />
|
||||
<body name="link7" pos="0.076 0.097 0" quat="0.707105 -0.707108 0 0">
|
||||
<inertial pos="-0.000420033 -0.00287433 0.0257078" quat="0.999372 -0.0349129 -0.00605634 0.000551744" mass="0.85624" diaginertia="0.00137671 0.00118744 0.000514968" />
|
||||
<joint name="joint7" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="2" frictionloss="1" />
|
||||
<geom name="j8" material="robot0:gripper_mat" type="mesh" rgba="0.753 0.753 0.753 1" mesh="link7" />
|
||||
<geom name="j9" material="robot0:gripper_mat" type="mesh" rgba="1 1 1 1" mesh="base_link" />
|
||||
<site name="grasp" pos="0 0 0.16" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
|
||||
<body name="left_outer_knuckle" pos="0 0.035 0.059098">
|
||||
<inertial pos="0 0.021559 0.015181" quat="0.47789 0.87842 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
|
||||
<joint name="drive_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_outer_knuckle" />
|
||||
<body name="left_finger" pos="0 0.035465 0.042039">
|
||||
<inertial pos="0 -0.016413 0.029258" quat="0.697634 0.115353 -0.115353 0.697634" mass="0.048304" diaginertia="1.88037e-05 1.7493e-05 3.56792e-06" />
|
||||
<joint name="left_finger_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||
<geom name="j10" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="left_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||
<body name="right_hand" pos="0 -0.03 0.05" quat="-0.7071 0 0 0.7071">
|
||||
<site name="ee" pos="0 0 0" rgba="0 0 1 0" type="sphere" group="1"/>
|
||||
<site name="ee_x" pos="0 0 0" size="0.005 .1" quat="0.707105 0.707108 0 0 " rgba="1 0 0 0" type="cylinder" group="1"/>
|
||||
<site name="ee_z" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0 0.707108" rgba="0 0 1 0" type="cylinder" group="1"/>
|
||||
<site name="ee_y" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0.707108 0 " rgba="0 1 0 0" type="cylinder" group="1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="left_inner_knuckle" pos="0 0.02 0.074098">
|
||||
<inertial pos="1.86601e-06 0.0220468 0.0261335" quat="0.664139 -0.242732 0.242713 0.664146" mass="0.0230126" diaginertia="8.34216e-06 6.0949e-06 2.75601e-06" />
|
||||
<joint name="left_inner_knuckle_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||
</body>
|
||||
<body name="right_outer_knuckle" pos="0 -0.035 0.059098">
|
||||
<inertial pos="0 -0.021559 0.015181" quat="0.87842 0.47789 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
|
||||
<joint name="right_outer_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_outer_knuckle" />
|
||||
<body name="right_finger" pos="0 -0.035465 0.042039">
|
||||
<inertial pos="0 0.016413 0.029258" quat="0.697634 -0.115356 0.115356 0.697634" mass="0.048304" diaginertia="1.88038e-05 1.7493e-05 3.56779e-06" />
|
||||
<joint name="right_finger_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||
<geom name="j11" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="right_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||
<body name="left_hand" pos="0 0.03 0.05" quat="-0.7071 0 0 0.7071">
|
||||
<site name="ee_2" pos="0 0 0" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="right_inner_knuckle" pos="0 -0.02 0.074098">
|
||||
<inertial pos="1.866e-06 -0.022047 0.026133" quat="0.66415 0.242702 -0.242721 0.664144" mass="0.023013" diaginertia="8.34209e-06 6.0949e-06 2.75601e-06" />
|
||||
<joint name="right_inner_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujoco>
|
||||
145
lerobot/common/envs/simxarm/simxarm/tasks/base.py
Normal file
145
lerobot/common/envs/simxarm/simxarm/tasks/base.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import os
|
||||
|
||||
import mujoco
|
||||
import numpy as np
|
||||
from gymnasium_robotics.envs import robot_env
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||
|
||||
|
||||
class Base(robot_env.MujocoRobotEnv):
|
||||
"""
|
||||
Superclass for all simxarm environments.
|
||||
Args:
|
||||
xml_name (str): name of the xml environment file
|
||||
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
|
||||
"""
|
||||
|
||||
def __init__(self, xml_name, gripper_rotation=None):
|
||||
if gripper_rotation is None:
|
||||
gripper_rotation = [0, 1, 0, 0]
|
||||
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
|
||||
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||
self.max_z = 1.2
|
||||
self.min_z = 0.2
|
||||
super().__init__(
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||
n_substeps=20,
|
||||
n_actions=4,
|
||||
initial_qpos={},
|
||||
)
|
||||
|
||||
@property
|
||||
def dt(self):
|
||||
return self.n_substeps * self.model.opt.timestep
|
||||
|
||||
@property
|
||||
def eef(self):
|
||||
return self._utils.get_site_xpos(self.model, self.data, "grasp")
|
||||
|
||||
@property
|
||||
def obj(self):
|
||||
return self._utils.get_site_xpos(self.model, self.data, "object_site")
|
||||
|
||||
@property
|
||||
def robot_state(self):
|
||||
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
|
||||
return np.concatenate([self.eef, gripper_angle])
|
||||
|
||||
def is_success(self):
|
||||
return NotImplementedError()
|
||||
|
||||
def get_reward(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _sample_goal(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_obs(self):
|
||||
return self._get_obs()
|
||||
|
||||
def _step_callback(self):
|
||||
self._mujoco.mj_forward(self.model, self.data)
|
||||
|
||||
def _limit_gripper(self, gripper_pos, pos_ctrl):
|
||||
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
|
||||
pos_ctrl[0] = min(pos_ctrl[0], 0)
|
||||
if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
|
||||
pos_ctrl[0] = max(pos_ctrl[0], 0)
|
||||
if gripper_pos[1] > self.center_of_table[1] + 0.3:
|
||||
pos_ctrl[1] = min(pos_ctrl[1], 0)
|
||||
if gripper_pos[1] < self.center_of_table[1] - 0.3:
|
||||
pos_ctrl[1] = max(pos_ctrl[1], 0)
|
||||
if gripper_pos[2] > self.max_z:
|
||||
pos_ctrl[2] = min(pos_ctrl[2], 0)
|
||||
if gripper_pos[2] < self.min_z:
|
||||
pos_ctrl[2] = max(pos_ctrl[2], 0)
|
||||
return pos_ctrl
|
||||
|
||||
def _apply_action(self, action):
|
||||
assert action.shape == (4,)
|
||||
action = action.copy()
|
||||
pos_ctrl, gripper_ctrl = action[:3], action[3]
|
||||
pos_ctrl = self._limit_gripper(
|
||||
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
|
||||
) * (1 / self.n_substeps)
|
||||
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
|
||||
mocap.apply_action(
|
||||
self.model,
|
||||
self._model_names,
|
||||
self.data,
|
||||
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
|
||||
)
|
||||
|
||||
def _render_callback(self):
|
||||
self._mujoco.mj_forward(self.model, self.data)
|
||||
|
||||
def _reset_sim(self):
|
||||
self.data.time = self.initial_time
|
||||
self.data.qpos[:] = np.copy(self.initial_qpos)
|
||||
self.data.qvel[:] = np.copy(self.initial_qvel)
|
||||
self._sample_goal()
|
||||
self._mujoco.mj_step(self.model, self.data, nstep=10)
|
||||
return True
|
||||
|
||||
def _set_gripper(self, gripper_pos, gripper_rotation):
|
||||
self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
|
||||
self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
|
||||
self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
|
||||
self.data.qpos[10] = 0.0
|
||||
self.data.qpos[12] = 0.0
|
||||
|
||||
def _env_setup(self, initial_qpos):
|
||||
for name, value in initial_qpos.items():
|
||||
self.data.set_joint_qpos(name, value)
|
||||
mocap.reset(self.model, self.data)
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
self._sample_goal()
|
||||
mujoco.mj_forward(self.model, self.data)
|
||||
|
||||
def reset(self):
|
||||
self._reset_sim()
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, action):
|
||||
assert action.shape == (4,)
|
||||
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
|
||||
self._apply_action(action)
|
||||
self._mujoco.mj_step(self.model, self.data, nstep=2)
|
||||
self._step_callback()
|
||||
obs = self._get_obs()
|
||||
reward = self.get_reward()
|
||||
done = False
|
||||
info = {"is_success": self.is_success(), "success": self.is_success()}
|
||||
return obs, reward, done, info
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
self._render_callback()
|
||||
# HACK
|
||||
self.model.vis.global_.offwidth = width
|
||||
self.model.vis.global_.offheight = height
|
||||
return self.mujoco_renderer.render(mode)
|
||||
|
||||
def close(self):
|
||||
if self.mujoco_renderer is not None:
|
||||
self.mujoco_renderer.close()
|
||||
100
lerobot/common/envs/simxarm/simxarm/tasks/lift.py
Normal file
100
lerobot/common/envs/simxarm/simxarm/tasks/lift.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import Base
|
||||
|
||||
|
||||
class Lift(Base):
|
||||
def __init__(self):
|
||||
self._z_threshold = 0.15
|
||||
super().__init__("lift")
|
||||
|
||||
@property
|
||||
def z_target(self):
|
||||
return self._init_z + self._z_threshold
|
||||
|
||||
def is_success(self):
|
||||
return self.obj[2] >= self.z_target
|
||||
|
||||
def get_reward(self):
|
||||
reach_dist = np.linalg.norm(self.obj - self.eef)
|
||||
reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1])
|
||||
pick_completed = self.obj[2] >= (self.z_target - 0.01)
|
||||
obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02)
|
||||
|
||||
# Reach
|
||||
if reach_dist < 0.05:
|
||||
reach_reward = -reach_dist + max(self._action[-1], 0) / 50
|
||||
elif reach_dist_xy < 0.05:
|
||||
reach_reward = -reach_dist
|
||||
else:
|
||||
z_bonus = np.linalg.norm(np.linalg.norm(self.obj[-1] - self.eef[-1]))
|
||||
reach_reward = -reach_dist - 2 * z_bonus
|
||||
|
||||
# Pick
|
||||
if pick_completed and not obj_dropped:
|
||||
pick_reward = self.z_target
|
||||
elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)):
|
||||
pick_reward = min(self.z_target, self.obj[2])
|
||||
else:
|
||||
pick_reward = 0
|
||||
|
||||
return reach_reward / 100 + pick_reward
|
||||
|
||||
def _get_obs(self):
|
||||
eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt
|
||||
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
|
||||
eef = self.eef - self.center_of_table
|
||||
|
||||
obj = self.obj - self.center_of_table
|
||||
obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]
|
||||
obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt
|
||||
obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt
|
||||
|
||||
obs = np.concatenate(
|
||||
[
|
||||
eef,
|
||||
eef_velp,
|
||||
obj,
|
||||
obj_rot,
|
||||
obj_velp,
|
||||
obj_velr,
|
||||
eef - obj,
|
||||
np.array(
|
||||
[
|
||||
np.linalg.norm(eef - obj),
|
||||
np.linalg.norm(eef[:-1] - obj[:-1]),
|
||||
self.z_target,
|
||||
self.z_target - obj[-1],
|
||||
self.z_target - eef[-1],
|
||||
]
|
||||
),
|
||||
gripper_angle,
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef}
|
||||
|
||||
def _sample_goal(self):
|
||||
# Gripper
|
||||
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||
|
||||
# Object
|
||||
object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07])
|
||||
object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1)
|
||||
object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1)
|
||||
object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")
|
||||
object_qpos[:3] = object_pos
|
||||
self._utils.set_joint_qpos(self.model, self.data, "object_joint0", object_qpos)
|
||||
self._init_z = object_pos[2]
|
||||
|
||||
# Goal
|
||||
return object_pos + np.array([0, 0, self._z_threshold])
|
||||
|
||||
def reset(self):
|
||||
self._action = np.zeros(4)
|
||||
return super().reset()
|
||||
|
||||
def step(self, action):
|
||||
self._action = action.copy()
|
||||
return super().step(action)
|
||||
67
lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
Normal file
67
lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# import mujoco_py
|
||||
import mujoco
|
||||
import numpy as np
|
||||
|
||||
|
||||
def apply_action(model, model_names, data, action):
|
||||
if model.nmocap > 0:
|
||||
pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
|
||||
if data.ctrl is not None:
|
||||
for i in range(gripper_action.shape[0]):
|
||||
data.ctrl[i] = gripper_action[i]
|
||||
pos_action = pos_action.reshape(model.nmocap, 7)
|
||||
pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
|
||||
reset_mocap2body_xpos(model, model_names, data)
|
||||
data.mocap_pos[:] = data.mocap_pos + pos_delta
|
||||
data.mocap_quat[:] = data.mocap_quat + quat_delta
|
||||
|
||||
|
||||
def reset(model, data):
|
||||
if model.nmocap > 0 and model.eq_data is not None:
|
||||
for i in range(model.eq_data.shape[0]):
|
||||
# if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD:
|
||||
if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD:
|
||||
# model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.])
|
||||
model.eq_data[i, :] = np.array(
|
||||
[
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
1.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
]
|
||||
)
|
||||
# sim.forward()
|
||||
mujoco.mj_forward(model, data)
|
||||
|
||||
|
||||
def reset_mocap2body_xpos(model, model_names, data):
|
||||
if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
|
||||
return
|
||||
|
||||
# For all weld constraints
|
||||
for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
|
||||
# if eq_type != mujoco_py.const.EQ_WELD:
|
||||
if eq_type != mujoco.mjtEq.mjEQ_WELD:
|
||||
continue
|
||||
# body2 = model.body_id2name(obj2_id)
|
||||
body2 = model_names.body_id2name[obj2_id]
|
||||
if body2 == "B0" or body2 == "B9" or body2 == "B1":
|
||||
continue
|
||||
mocap_id = model.body_mocapid[obj1_id]
|
||||
if mocap_id != -1:
|
||||
# obj1 is the mocap, obj2 is the welded body
|
||||
body_idx = obj2_id
|
||||
else:
|
||||
# obj2 is the mocap, obj1 is the welded body
|
||||
mocap_id = model.body_mocapid[obj2_id]
|
||||
body_idx = obj1_id
|
||||
assert mocap_id != -1
|
||||
data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
|
||||
data.mocap_quat[mocap_id][:] = data.xquat[body_idx]
|
||||
86
lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
Normal file
86
lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import Base
|
||||
|
||||
|
||||
class PegInBox(Base):
|
||||
def __init__(self):
|
||||
super().__init__("peg_in_box")
|
||||
|
||||
def _reset_sim(self):
|
||||
self._act_magnitude = 0
|
||||
super()._reset_sim()
|
||||
for _ in range(10):
|
||||
self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32))
|
||||
self.sim.step()
|
||||
|
||||
@property
|
||||
def box(self):
|
||||
return self.sim.data.get_site_xpos("box_site")
|
||||
|
||||
def is_success(self):
|
||||
return np.linalg.norm(self.obj - self.box) <= 0.05
|
||||
|
||||
def get_reward(self):
|
||||
dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2])
|
||||
dist_xyz = np.linalg.norm(self.obj - self.box)
|
||||
return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy
|
||||
|
||||
def _get_obs(self):
|
||||
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||
eef, box = self.eef - self.center_of_table, self.box - self.center_of_table
|
||||
|
||||
obj = self.obj - self.center_of_table
|
||||
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
|
||||
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
|
||||
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
|
||||
|
||||
obs = np.concatenate(
|
||||
[
|
||||
eef,
|
||||
eef_velp,
|
||||
box,
|
||||
obj,
|
||||
obj_rot,
|
||||
obj_velp,
|
||||
obj_velr,
|
||||
eef - box,
|
||||
eef - obj,
|
||||
obj - box,
|
||||
np.array(
|
||||
[
|
||||
np.linalg.norm(eef - box),
|
||||
np.linalg.norm(eef - obj),
|
||||
np.linalg.norm(obj - box),
|
||||
gripper_angle,
|
||||
]
|
||||
),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box}
|
||||
|
||||
def _sample_goal(self):
|
||||
# Gripper
|
||||
gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||
|
||||
# Object
|
||||
object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3)
|
||||
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
|
||||
object_qpos[:3] = object_pos
|
||||
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
|
||||
|
||||
# Box
|
||||
box_pos = np.array([1.61, 0.18, 0.58])
|
||||
box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2)
|
||||
box_qpos = self.sim.data.get_joint_qpos("box_joint0")
|
||||
box_qpos[:3] = box_pos
|
||||
self.sim.data.set_joint_qpos("box_joint0", box_qpos)
|
||||
|
||||
return self.box
|
||||
|
||||
def step(self, action):
|
||||
self._act_magnitude = np.linalg.norm(action[:3])
|
||||
return super().step(action)
|
||||
78
lerobot/common/envs/simxarm/simxarm/tasks/push.py
Normal file
78
lerobot/common/envs/simxarm/simxarm/tasks/push.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import Base
|
||||
|
||||
|
||||
class Push(Base):
|
||||
def __init__(self):
|
||||
super().__init__("push")
|
||||
|
||||
def _reset_sim(self):
|
||||
self._act_magnitude = 0
|
||||
super()._reset_sim()
|
||||
|
||||
def is_success(self):
|
||||
return np.linalg.norm(self.obj - self.goal) <= 0.05
|
||||
|
||||
def get_reward(self):
|
||||
dist = np.linalg.norm(self.obj - self.goal)
|
||||
penalty = self._act_magnitude**2
|
||||
return -(dist + 0.15 * penalty)
|
||||
|
||||
def _get_obs(self):
|
||||
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
|
||||
|
||||
obj = self.obj - self.center_of_table
|
||||
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
|
||||
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
|
||||
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
|
||||
|
||||
obs = np.concatenate(
|
||||
[
|
||||
eef,
|
||||
eef_velp,
|
||||
goal,
|
||||
obj,
|
||||
obj_rot,
|
||||
obj_velp,
|
||||
obj_velr,
|
||||
eef - goal,
|
||||
eef - obj,
|
||||
obj - goal,
|
||||
np.array(
|
||||
[
|
||||
np.linalg.norm(eef - goal),
|
||||
np.linalg.norm(eef - obj),
|
||||
np.linalg.norm(obj - goal),
|
||||
gripper_angle,
|
||||
]
|
||||
),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
|
||||
|
||||
def _sample_goal(self):
|
||||
# Gripper
|
||||
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||
|
||||
# Object
|
||||
object_pos = self.center_of_table - np.array([0.25, 0, 0.07])
|
||||
object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1)
|
||||
object_pos[1] += self.np_random.uniform(-0.08, 0.08, size=1)
|
||||
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
|
||||
object_qpos[:3] = object_pos
|
||||
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
|
||||
|
||||
# Goal
|
||||
self.goal = np.array([1.600, 0.200, 0.545])
|
||||
self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2)
|
||||
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
|
||||
return self.goal
|
||||
|
||||
def step(self, action):
|
||||
self._act_magnitude = np.linalg.norm(action[:3])
|
||||
return super().step(action)
|
||||
44
lerobot/common/envs/simxarm/simxarm/tasks/reach.py
Normal file
44
lerobot/common/envs/simxarm/simxarm/tasks/reach.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import Base
|
||||
|
||||
|
||||
class Reach(Base):
|
||||
def __init__(self):
|
||||
super().__init__("reach")
|
||||
|
||||
def _reset_sim(self):
|
||||
self._act_magnitude = 0
|
||||
super()._reset_sim()
|
||||
|
||||
def is_success(self):
|
||||
return np.linalg.norm(self.eef - self.goal) <= 0.05
|
||||
|
||||
def get_reward(self):
|
||||
dist = np.linalg.norm(self.eef - self.goal)
|
||||
penalty = self._act_magnitude**2
|
||||
return -(dist + 0.15 * penalty)
|
||||
|
||||
def _get_obs(self):
|
||||
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
|
||||
obs = np.concatenate(
|
||||
[eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0
|
||||
)
|
||||
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
|
||||
|
||||
def _sample_goal(self):
|
||||
# Gripper
|
||||
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||
|
||||
# Goal
|
||||
self.goal = np.array([1.550, 0.287, 0.580])
|
||||
self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2)
|
||||
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
|
||||
return self.goal
|
||||
|
||||
def step(self, action):
|
||||
self._act_magnitude = np.linalg.norm(action[:3])
|
||||
return super().step(action)
|
||||
@@ -5,6 +5,11 @@ from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg, return_list=False):
|
||||
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
|
||||
@@ -26,6 +31,7 @@ class Logger:
|
||||
self._model_dir = self._log_dir / "models"
|
||||
self._buffer_dir = self._log_dir / "buffers"
|
||||
self._save_model = cfg.save_model
|
||||
self._disable_wandb_artifact = cfg.wandb.disable_artifact
|
||||
self._save_buffer = cfg.save_buffer
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
@@ -34,7 +40,7 @@ class Logger:
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project or not entity
|
||||
run_offline = not enable_wandb or not project
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
@@ -59,16 +65,18 @@ class Logger:
|
||||
resume=None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
def save_model(self, policy: AbstractPolicy, identifier):
|
||||
if self._save_model:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
if self._wandb:
|
||||
policy.save_pretrained(fp)
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
|
||||
93
lerobot/common/policies/abstract.py
Normal file
93
lerobot/common/policies/abstract.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Base policy which all policies should be derived from.
|
||||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
documentation for more information.
|
||||
|
||||
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
|
||||
# Save policy weights to local directory
|
||||
>>> policy.save_pretrained("my-awesome-policy")
|
||||
|
||||
# Push policy weights to the Hub
|
||||
>>> policy.push_to_hub("my-awesome-policy")
|
||||
|
||||
# Download and initialize policy from the Hub
|
||||
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||
|
||||
def __init__(self, n_action_steps: int | None = None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||
adds that dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
||||
self.n_action_steps = n_action_steps
|
||||
self.clear_action_queue()
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def select_actions(self, observation) -> Tensor:
|
||||
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||
|
||||
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||
"""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def clear_action_queue(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Tensor:
|
||||
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||
|
||||
WARNING: In general, this should not be overriden.
|
||||
|
||||
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
||||
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
||||
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
||||
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||
the subclass doesn't have to.
|
||||
|
||||
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
||||
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||
the action trajectory horizon and * is the action dimensions.
|
||||
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||
"""
|
||||
if self.n_action_steps is None:
|
||||
return self.select_actions(*args, **kwargs)
|
||||
if len(self._action_queue) == 0:
|
||||
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
|
||||
# (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
115
lerobot/common/policies/act/backbone.py
Normal file
115
lerobot/common/policies/act/backbone.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
from .utils import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
def __init__(
|
||||
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for _, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
212
lerobot/common/policies/act/detr_vae.py
Normal file
212
lerobot/common/policies/act/detr_vae.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vae = vae
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
# TODO(rcadene): understand what is env_state, and why it needs to be 7
|
||||
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if self.vae and is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for _ in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=args.state_dim,
|
||||
action_dim=args.action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
vae=args.vae,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
||||
216
lerobot/common/policies/act/policy.py
Normal file
216
lerobot/common/policies/act/policy.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
model = build(cfg)
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
self.to(self.device)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# self.lr_scheduler.step()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp, device=None):
|
||||
d = torch.load(fp, map_location=device)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
if observation["image"].shape[0] != 1:
|
||||
raise NotImplementedError("Batch size > 1 not handled")
|
||||
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.vae:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
||||
102
lerobot/common/policies/act/position_encoding.py
Normal file
102
lerobot/common/policies/act/position_encoding.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .utils import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = (
|
||||
torch.cat(
|
||||
[
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(x.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
n_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ("v2", "sine"):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
||||
elif args.position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(n_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
371
lerobot/common/policies/act/transformer.py
Normal file
371
lerobot/common/policies/act/transformer.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
478
lerobot/common/policies/act/utils.py
Normal file
478
lerobot/common/policies/act/utils.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
from packaging import version
|
||||
from torch import Tensor
|
||||
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {}".format(name, str(meter)))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
]
|
||||
)
|
||||
mega_b = 1024.0 * 1024.0
|
||||
for i, obj in enumerate(iterable):
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / mega_b,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
||||
|
||||
sha = "N/A"
|
||||
diff = "clean"
|
||||
branch = "N/A"
|
||||
try:
|
||||
sha = _run(["git", "rev-parse", "HEAD"])
|
||||
subprocess.check_output(["git", "diff"], cwd=cwd)
|
||||
diff = _run(["git", "diff-index", "HEAD"])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch, strict=False))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("not supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
|
||||
torch.int64
|
||||
)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
elif "SLURM_PROCID" in os.environ:
|
||||
args.rank = int(os.environ["SLURM_PROCID"])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
@@ -1,3 +1,44 @@
|
||||
"""Code from the original diffusion policy project.
|
||||
|
||||
Notes on how to load a checkpoint from the original repository:
|
||||
|
||||
In the original repository, run the eval and use a breakpoint to extract the policy weights.
|
||||
|
||||
```
|
||||
torch.save(policy.state_dict(), "weights.pt")
|
||||
```
|
||||
|
||||
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
|
||||
|
||||
```
|
||||
loaded = torch.load("weights.pt")
|
||||
aligned = {}
|
||||
their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
our_prefix = "obs_encoder.key_model_map.image.backbone"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.pool"
|
||||
our_prefix = "obs_encoder.key_model_map.image.pool"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.nets.3"
|
||||
our_prefix = "obs_encoder.key_model_map.image.out"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
|
||||
# Note: here you are loading into the ema model.
|
||||
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
|
||||
assert all('_dummy_variable' in k for k in missing_keys)
|
||||
assert len(unexpected_keys) == 0
|
||||
```
|
||||
|
||||
Then in that same runtime you can also save the weights with the new aligned state_dict:
|
||||
|
||||
```
|
||||
policy.save_pretrained("my-policy")
|
||||
```
|
||||
|
||||
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
@@ -5,11 +46,33 @@ import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from einops import reduce
|
||||
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
|
||||
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D
|
||||
from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply
|
||||
|
||||
|
||||
class BaseImagePolicy(ModuleAttrMixin):
|
||||
# init accepts keyword argument shape_meta, see config/task/*_image.yaml
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict:
|
||||
str: B,To,*
|
||||
return: B,Ta,Da
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
# reset state for stateful policies
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
# ========== training ===========
|
||||
# no standard training interface except setting normalizer
|
||||
def set_normalizer(self, normalizer: LinearNormalizer):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
@@ -168,11 +231,10 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
|
||||
# run sampling
|
||||
nsample = self.conditional_sample(
|
||||
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
|
||||
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond
|
||||
)
|
||||
|
||||
action_pred = nsample[..., :action_dim]
|
||||
|
||||
# get action
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
|
||||
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
286
lerobot/common/policies/diffusion/model/conditional_unet1d.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d
|
||||
from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
|
||||
]
|
||||
)
|
||||
|
||||
# FiLM modulation https://arxiv.org/abs/1709.07871
|
||||
# predicts per-channel scale and bias
|
||||
cond_channels = out_channels
|
||||
if cond_predict_scale:
|
||||
cond_channels = out_channels * 2
|
||||
self.cond_predict_scale = cond_predict_scale
|
||||
self.out_channels = out_channels
|
||||
self.cond_encoder = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(cond_dim, cond_channels),
|
||||
Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
# make sure dimensions compatible
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
"""
|
||||
x : [ batch_size x in_channels x horizon ]
|
||||
cond : [ batch_size x cond_dim]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x)
|
||||
embed = self.cond_encoder(cond)
|
||||
if self.cond_predict_scale:
|
||||
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
|
||||
scale = embed[:, 0, ...]
|
||||
bias = embed[:, 1, ...]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
out = out + embed
|
||||
out = self.blocks[1](out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=None,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=None,
|
||||
kernel_size=3,
|
||||
n_groups=8,
|
||||
cond_predict_scale=False,
|
||||
):
|
||||
super().__init__()
|
||||
if down_dims is None:
|
||||
down_dims = [256, 512, 1024]
|
||||
|
||||
all_dims = [input_dim] + list(down_dims)
|
||||
start_dim = down_dims[0]
|
||||
|
||||
dsed = diffusion_step_embed_dim
|
||||
diffusion_step_encoder = nn.Sequential(
|
||||
SinusoidalPosEmb(dsed),
|
||||
nn.Linear(dsed, dsed * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(dsed * 4, dsed),
|
||||
)
|
||||
cond_dim = dsed
|
||||
if global_cond_dim is not None:
|
||||
cond_dim += global_cond_dim
|
||||
|
||||
in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False))
|
||||
|
||||
local_cond_encoder = None
|
||||
if local_cond_dim is not None:
|
||||
_, dim_out = in_out[0]
|
||||
dim_in = local_cond_dim
|
||||
local_cond_encoder = nn.ModuleList(
|
||||
[
|
||||
# down encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
# up encoder
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
mid_dim = all_dims[-1]
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
mid_dim,
|
||||
mid_dim,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out,
|
||||
dim_out,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Downsample1d(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
up_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
ConditionalResidualBlock1D(
|
||||
dim_out * 2,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
ConditionalResidualBlock1D(
|
||||
dim_in,
|
||||
dim_in,
|
||||
cond_dim=cond_dim,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
),
|
||||
Upsample1d(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
final_conv = nn.Sequential(
|
||||
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
|
||||
nn.Conv1d(start_dim, input_dim, 1),
|
||||
)
|
||||
|
||||
self.diffusion_step_encoder = diffusion_step_encoder
|
||||
self.local_cond_encoder = local_cond_encoder
|
||||
self.up_modules = up_modules
|
||||
self.down_modules = down_modules
|
||||
self.final_conv = final_conv
|
||||
|
||||
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
x: (B,T,input_dim)
|
||||
timestep: (B,) or int, diffusion step
|
||||
local_cond: (B,T,local_cond_dim)
|
||||
global_cond: (B,global_cond_dim)
|
||||
output: (B,T,input_dim)
|
||||
"""
|
||||
sample = einops.rearrange(sample, "b h t -> b t h")
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
global_feature = self.diffusion_step_encoder(timesteps)
|
||||
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([global_feature, global_cond], axis=-1)
|
||||
|
||||
# encode local features
|
||||
h_local = []
|
||||
if local_cond is not None:
|
||||
local_cond = einops.rearrange(local_cond, "b h t -> b t h")
|
||||
resnet, resnet2 = self.local_cond_encoder
|
||||
x = resnet(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
x = resnet2(local_cond, global_feature)
|
||||
h_local.append(x)
|
||||
|
||||
x = sample
|
||||
h = []
|
||||
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
|
||||
x = resnet(x, global_feature)
|
||||
if idx == 0 and len(h_local) > 0:
|
||||
x = x + h_local[0]
|
||||
x = resnet2(x, global_feature)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
# The correct condition should be:
|
||||
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
|
||||
# However this change will break compatibility with published checkpoints.
|
||||
# Therefore it is left as a comment.
|
||||
if idx == len(self.up_modules) and len(h_local) > 0:
|
||||
x = x + h_local[1]
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b t h -> b h t")
|
||||
return x
|
||||
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
47
lerobot/common/policies/diffusion/model/conv1d_components.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch.nn as nn
|
||||
|
||||
# from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# def test():
|
||||
# cb = Conv1dBlock(256, 128, kernel_size=3)
|
||||
# x = torch.zeros((1,256,16))
|
||||
# o = cb(x)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user