forked from tangger/lerobot
Compare commits
3 Commits
qgallouede
...
thom-propo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1e47202c0 | ||
|
|
24821fee24 | ||
|
|
4751642ace |
54
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
54
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,54 +0,0 @@
|
|||||||
name: "\U0001F41B Bug Report"
|
|
||||||
description: Submit a bug report to help us improve LeRobot
|
|
||||||
body:
|
|
||||||
- type: markdown
|
|
||||||
attributes:
|
|
||||||
value: |
|
|
||||||
Thanks for taking the time to submit a bug report! 🐛
|
|
||||||
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: system-info
|
|
||||||
attributes:
|
|
||||||
label: System Info
|
|
||||||
description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
|
|
||||||
render: Shell
|
|
||||||
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
|
|
||||||
- type: checkboxes
|
|
||||||
id: information-scripts-examples
|
|
||||||
attributes:
|
|
||||||
label: Information
|
|
||||||
description: 'The problem arises when using:'
|
|
||||||
options:
|
|
||||||
- label: "One of the scripts in the examples/ folder of LeRobot"
|
|
||||||
- label: "My own task or dataset (give details below)"
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: reproduction
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
attributes:
|
|
||||||
label: Reproduction
|
|
||||||
description: |
|
|
||||||
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
|
||||||
Sharing error messages or stack traces could be useful as well!
|
|
||||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
|
||||||
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
|
|
||||||
|
|
||||||
placeholder: |
|
|
||||||
Steps to reproduce the behavior:
|
|
||||||
|
|
||||||
1.
|
|
||||||
2.
|
|
||||||
3.
|
|
||||||
|
|
||||||
- type: textarea
|
|
||||||
id: expected-behavior
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
attributes:
|
|
||||||
label: Expected behavior
|
|
||||||
description: "A clear and concise description of what you would expect to happen."
|
|
||||||
15
.github/PULL_REQUEST_TEMPLATE.md
vendored
15
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,15 +0,0 @@
|
|||||||
# What does this PR do?
|
|
||||||
|
|
||||||
Example: Fixes # (issue)
|
|
||||||
|
|
||||||
|
|
||||||
## Before submitting
|
|
||||||
- Read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
|
||||||
- Provide a minimal code example for the reviewer to checkout & try.
|
|
||||||
- Explain how you tested your changes.
|
|
||||||
|
|
||||||
|
|
||||||
## Who can review?
|
|
||||||
|
|
||||||
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
|
||||||
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
|
|
||||||
1076
.github/poetry/cpu/poetry.lock
generated
vendored
1076
.github/poetry/cpu/poetry.lock
generated
vendored
File diff suppressed because it is too large
Load Diff
58
.github/poetry/cpu/pyproject.toml
vendored
58
.github/poetry/cpu/pyproject.toml
vendored
@@ -1,25 +1,19 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "lerobot"
|
name = "lerobot"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
description = "Le robot is learning"
|
||||||
authors = [
|
authors = [
|
||||||
"Rémi Cadène <re.cadene@gmail.com>",
|
"Rémi Cadène <re.cadene@gmail.com>",
|
||||||
"Alexander Soare <alexander.soare159@gmail.com>",
|
|
||||||
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
|
|
||||||
"Simon Alibert <alibert.sim@gmail.com>",
|
"Simon Alibert <alibert.sim@gmail.com>",
|
||||||
"Thomas Wolf <thomaswolfcontact@gmail.com>",
|
|
||||||
]
|
]
|
||||||
repository = "https://github.com/huggingface/lerobot"
|
repository = "https://github.com/Cadene/lerobot"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "MIT"
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
|
||||||
"Intended Audience :: Science/Research",
|
|
||||||
"Topic :: Software Development :: Build Tools",
|
"Topic :: Software Development :: Build Tools",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"License :: OSI Approved :: MIT License",
|
||||||
"License :: OSI Approved :: Apache Software License",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
]
|
]
|
||||||
packages = [{include = "lerobot"}]
|
packages = [{include = "lerobot"}]
|
||||||
@@ -29,39 +23,43 @@ packages = [{include = "lerobot"}]
|
|||||||
python = "^3.10"
|
python = "^3.10"
|
||||||
termcolor = "^2.4.0"
|
termcolor = "^2.4.0"
|
||||||
omegaconf = "^2.3.0"
|
omegaconf = "^2.3.0"
|
||||||
|
dm-env = "^1.6"
|
||||||
|
pandas = "^2.2.1"
|
||||||
wandb = "^0.16.3"
|
wandb = "^0.16.3"
|
||||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
moviepy = "^1.0.3"
|
||||||
|
imageio = {extras = ["pyav"], version = "^2.34.0"}
|
||||||
gdown = "^5.1.0"
|
gdown = "^5.1.0"
|
||||||
hydra-core = "^1.3.2"
|
hydra-core = "^1.3.2"
|
||||||
einops = "^0.7.0"
|
einops = "^0.7.0"
|
||||||
|
pygame = "^2.5.2"
|
||||||
pymunk = "^6.6.0"
|
pymunk = "^6.6.0"
|
||||||
zarr = "^2.17.0"
|
zarr = "^2.17.0"
|
||||||
|
shapely = "^2.0.3"
|
||||||
|
scikit-image = "^0.22.0"
|
||||||
numba = "^0.59.0"
|
numba = "^0.59.0"
|
||||||
|
mpmath = "^1.3.0"
|
||||||
torch = {version = "^2.2.1", source = "torch-cpu"}
|
torch = {version = "^2.2.1", source = "torch-cpu"}
|
||||||
|
tensordict = {git = "https://github.com/pytorch/tensordict"}
|
||||||
|
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
|
||||||
|
mujoco = "^2.3.7"
|
||||||
opencv-python = "^4.9.0.80"
|
opencv-python = "^4.9.0.80"
|
||||||
diffusers = "^0.26.3"
|
diffusers = "^0.26.3"
|
||||||
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
||||||
h5py = "^3.10.0"
|
h5py = "^3.10.0"
|
||||||
huggingface-hub = "^0.21.4"
|
dm = "^1.3"
|
||||||
|
dm-control = "1.0.14"
|
||||||
robomimic = "0.2.0"
|
robomimic = "0.2.0"
|
||||||
|
huggingface-hub = "^0.21.4"
|
||||||
|
gymnasium-robotics = "^1.2.4"
|
||||||
gymnasium = "^0.29.1"
|
gymnasium = "^0.29.1"
|
||||||
cmake = "^3.29.0.1"
|
cmake = "^3.29.0.1"
|
||||||
gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
|
||||||
gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
|
|
||||||
gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
|
|
||||||
pre-commit = {version = "^3.7.0", optional = true}
|
|
||||||
debugpy = {version = "^1.8.1", optional = true}
|
|
||||||
pytest = {version = "^8.1.0", optional = true}
|
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
|
||||||
datasets = "^2.19.0"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pusht = ["gym-pusht"]
|
pre-commit = "^3.6.2"
|
||||||
xarm = ["gym-xarm"]
|
debugpy = "^1.8.1"
|
||||||
aloha = ["gym-aloha"]
|
pytest = "^8.1.0"
|
||||||
dev = ["pre-commit", "debugpy"]
|
pytest-cov = "^5.0.0"
|
||||||
test = ["pytest", "pytest-cov"]
|
|
||||||
|
|
||||||
|
|
||||||
[[tool.poetry.source]]
|
[[tool.poetry.source]]
|
||||||
@@ -102,6 +100,10 @@ exclude = [
|
|||||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry-dynamic-versioning]
|
||||||
|
enable = true
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.5.0"]
|
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry_dynamic_versioning.backend"
|
||||||
|
|||||||
57
.github/workflows/test.yml
vendored
57
.github/workflows/test.yml
vendored
@@ -34,11 +34,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
- name: Add SSH key for installing envs
|
|
||||||
uses: webfactory/ssh-agent@v0.9.0
|
|
||||||
with:
|
|
||||||
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
|
||||||
|
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
# install & configure poetry
|
# install & configure poetry
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
@@ -92,7 +87,7 @@ jobs:
|
|||||||
TMP: ~/tmp
|
TMP: ~/tmp
|
||||||
run: |
|
run: |
|
||||||
mkdir ~/tmp
|
mkdir ~/tmp
|
||||||
poetry install --no-interaction --no-root --all-extras
|
poetry install --no-interaction --no-root
|
||||||
|
|
||||||
- name: Save cached venv
|
- name: Save cached venv
|
||||||
if: |
|
if: |
|
||||||
@@ -111,15 +106,17 @@ jobs:
|
|||||||
# install project
|
# install project
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
- name: Install project
|
- name: Install project
|
||||||
run: poetry install --no-interaction --all-extras
|
run: poetry install --no-interaction
|
||||||
|
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
# run tests & coverage
|
# run tests & coverage
|
||||||
#----------------------------------------------
|
#----------------------------------------------
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
|
env:
|
||||||
|
LEROBOT_TESTS_DEVICE: cpu
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
pytest -v --cov=./lerobot --cov-report=xml tests
|
pytest --cov=./lerobot --cov-report=xml tests
|
||||||
|
|
||||||
# TODO(aliberts): Link with HF Codecov account
|
# TODO(aliberts): Link with HF Codecov account
|
||||||
# - name: Upload coverage reports to Codecov with GitHub Action
|
# - name: Upload coverage reports to Codecov with GitHub Action
|
||||||
@@ -140,12 +137,10 @@ jobs:
|
|||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
offline_steps=2 \
|
offline_steps=2 \
|
||||||
online_steps=0 \
|
online_steps=0 \
|
||||||
eval_episodes=1 \
|
|
||||||
device=cpu \
|
device=cpu \
|
||||||
save_model=true \
|
save_model=true \
|
||||||
save_freq=2 \
|
save_freq=2 \
|
||||||
policy.n_action_steps=20 \
|
horizon=20 \
|
||||||
policy.chunk_size=20 \
|
|
||||||
policy.batch_size=2 \
|
policy.batch_size=2 \
|
||||||
hydra.run.dir=tests/outputs/act/
|
hydra.run.dir=tests/outputs/act/
|
||||||
|
|
||||||
@@ -159,6 +154,17 @@ jobs:
|
|||||||
device=cpu \
|
device=cpu \
|
||||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||||
|
|
||||||
|
# TODO(aliberts): This takes ~2mn to run, needs to be improved
|
||||||
|
# - name: Test eval ACT on ALOHA end-to-end (policy is None)
|
||||||
|
# run: |
|
||||||
|
# source .venv/bin/activate
|
||||||
|
# python lerobot/scripts/eval.py \
|
||||||
|
# --config lerobot/configs/default.yaml \
|
||||||
|
# policy=act \
|
||||||
|
# env=aloha \
|
||||||
|
# eval_episodes=1 \
|
||||||
|
# device=cpu
|
||||||
|
|
||||||
- name: Test train Diffusion on PushT end-to-end
|
- name: Test train Diffusion on PushT end-to-end
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
@@ -168,11 +174,9 @@ jobs:
|
|||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
offline_steps=2 \
|
offline_steps=2 \
|
||||||
online_steps=0 \
|
online_steps=0 \
|
||||||
eval_episodes=1 \
|
|
||||||
device=cpu \
|
device=cpu \
|
||||||
save_model=true \
|
save_model=true \
|
||||||
save_freq=2 \
|
save_freq=2 \
|
||||||
policy.batch_size=2 \
|
|
||||||
hydra.run.dir=tests/outputs/diffusion/
|
hydra.run.dir=tests/outputs/diffusion/
|
||||||
|
|
||||||
- name: Test eval Diffusion on PushT end-to-end
|
- name: Test eval Diffusion on PushT end-to-end
|
||||||
@@ -185,21 +189,28 @@ jobs:
|
|||||||
device=cpu \
|
device=cpu \
|
||||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||||
|
|
||||||
|
- name: Test eval Diffusion on PushT end-to-end (policy is None)
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
--config lerobot/configs/default.yaml \
|
||||||
|
policy=diffusion \
|
||||||
|
env=pusht \
|
||||||
|
eval_episodes=1 \
|
||||||
|
device=cpu
|
||||||
|
|
||||||
- name: Test train TDMPC on Simxarm end-to-end
|
- name: Test train TDMPC on Simxarm end-to-end
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=tdmpc \
|
policy=tdmpc \
|
||||||
env=xarm \
|
env=simxarm \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
offline_steps=1 \
|
offline_steps=1 \
|
||||||
online_steps=2 \
|
online_steps=1 \
|
||||||
eval_episodes=1 \
|
|
||||||
env.episode_length=2 \
|
|
||||||
device=cpu \
|
device=cpu \
|
||||||
save_model=true \
|
save_model=true \
|
||||||
save_freq=2 \
|
save_freq=2 \
|
||||||
policy.batch_size=2 \
|
|
||||||
hydra.run.dir=tests/outputs/tdmpc/
|
hydra.run.dir=tests/outputs/tdmpc/
|
||||||
|
|
||||||
- name: Test eval TDMPC on Simxarm end-to-end
|
- name: Test eval TDMPC on Simxarm end-to-end
|
||||||
@@ -211,3 +222,13 @@ jobs:
|
|||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||||
|
|
||||||
|
- name: Test eval TDPMC on Simxarm end-to-end (policy is None)
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
--config lerobot/configs/default.yaml \
|
||||||
|
policy=tdmpc \
|
||||||
|
env=simxarm \
|
||||||
|
eval_episodes=1 \
|
||||||
|
device=cpu
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,9 +11,6 @@ rl
|
|||||||
nautilus/*.yaml
|
nautilus/*.yaml
|
||||||
*.key
|
*.key
|
||||||
|
|
||||||
# Slurm
|
|
||||||
sbatch*.sh
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
exclude: ^(data/|tests/data)
|
exclude: ^(data/|tests/)
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.10
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.6.0
|
rev: v4.5.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
@@ -18,7 +18,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.3.7
|
rev: v0.3.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
|
|
||||||
# Contributor Covenant Code of Conduct
|
|
||||||
|
|
||||||
## Our Pledge
|
|
||||||
|
|
||||||
We as members, contributors, and leaders pledge to make participation in our
|
|
||||||
community a harassment-free experience for everyone, regardless of age, body
|
|
||||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
|
||||||
identity and expression, level of experience, education, socio-economic status,
|
|
||||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
|
||||||
identity and orientation.
|
|
||||||
|
|
||||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
|
||||||
diverse, inclusive, and healthy community.
|
|
||||||
|
|
||||||
## Our Standards
|
|
||||||
|
|
||||||
Examples of behavior that contributes to a positive environment for our
|
|
||||||
community include:
|
|
||||||
|
|
||||||
* Demonstrating empathy and kindness toward other people
|
|
||||||
* Being respectful of differing opinions, viewpoints, and experiences
|
|
||||||
* Giving and gracefully accepting constructive feedback
|
|
||||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
|
||||||
and learning from the experience
|
|
||||||
* Focusing on what is best not just for us as individuals, but for the overall
|
|
||||||
community
|
|
||||||
|
|
||||||
Examples of unacceptable behavior include:
|
|
||||||
|
|
||||||
* The use of sexualized language or imagery, and sexual attention or advances of
|
|
||||||
any kind
|
|
||||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
|
||||||
* Public or private harassment
|
|
||||||
* Publishing others' private information, such as a physical or email address,
|
|
||||||
without their explicit permission
|
|
||||||
* Other conduct which could reasonably be considered inappropriate in a
|
|
||||||
professional setting
|
|
||||||
|
|
||||||
## Enforcement Responsibilities
|
|
||||||
|
|
||||||
Community leaders are responsible for clarifying and enforcing our standards of
|
|
||||||
acceptable behavior and will take appropriate and fair corrective action in
|
|
||||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
|
||||||
or harmful.
|
|
||||||
|
|
||||||
Community leaders have the right and responsibility to remove, edit, or reject
|
|
||||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
|
||||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
|
||||||
decisions when appropriate.
|
|
||||||
|
|
||||||
## Scope
|
|
||||||
|
|
||||||
This Code of Conduct applies within all community spaces, and also applies when
|
|
||||||
an individual is officially representing the community in public spaces.
|
|
||||||
Examples of representing our community include using an official email address,
|
|
||||||
posting via an official social media account, or acting as an appointed
|
|
||||||
representative at an online or offline event.
|
|
||||||
|
|
||||||
## Enforcement
|
|
||||||
|
|
||||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
|
||||||
reported to the community leaders responsible for enforcement at
|
|
||||||
[feedback@huggingface.co](mailto:feedback@huggingface.co).
|
|
||||||
All complaints will be reviewed and investigated promptly and fairly.
|
|
||||||
|
|
||||||
All community leaders are obligated to respect the privacy and security of the
|
|
||||||
reporter of any incident.
|
|
||||||
|
|
||||||
## Enforcement Guidelines
|
|
||||||
|
|
||||||
Community leaders will follow these Community Impact Guidelines in determining
|
|
||||||
the consequences for any action they deem in violation of this Code of Conduct:
|
|
||||||
|
|
||||||
### 1. Correction
|
|
||||||
|
|
||||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
|
||||||
unprofessional or unwelcome in the community.
|
|
||||||
|
|
||||||
**Consequence**: A private, written warning from community leaders, providing
|
|
||||||
clarity around the nature of the violation and an explanation of why the
|
|
||||||
behavior was inappropriate. A public apology may be requested.
|
|
||||||
|
|
||||||
### 2. Warning
|
|
||||||
|
|
||||||
**Community Impact**: A violation through a single incident or series of
|
|
||||||
actions.
|
|
||||||
|
|
||||||
**Consequence**: A warning with consequences for continued behavior. No
|
|
||||||
interaction with the people involved, including unsolicited interaction with
|
|
||||||
those enforcing the Code of Conduct, for a specified period of time. This
|
|
||||||
includes avoiding interactions in community spaces as well as external channels
|
|
||||||
like social media. Violating these terms may lead to a temporary or permanent
|
|
||||||
ban.
|
|
||||||
|
|
||||||
### 3. Temporary Ban
|
|
||||||
|
|
||||||
**Community Impact**: A serious violation of community standards, including
|
|
||||||
sustained inappropriate behavior.
|
|
||||||
|
|
||||||
**Consequence**: A temporary ban from any sort of interaction or public
|
|
||||||
communication with the community for a specified period of time. No public or
|
|
||||||
private interaction with the people involved, including unsolicited interaction
|
|
||||||
with those enforcing the Code of Conduct, is allowed during this period.
|
|
||||||
Violating these terms may lead to a permanent ban.
|
|
||||||
|
|
||||||
### 4. Permanent Ban
|
|
||||||
|
|
||||||
**Community Impact**: Demonstrating a pattern of violation of community
|
|
||||||
standards, including sustained inappropriate behavior, harassment of an
|
|
||||||
individual, or aggression toward or disparagement of classes of individuals.
|
|
||||||
|
|
||||||
**Consequence**: A permanent ban from any sort of public interaction within the
|
|
||||||
community.
|
|
||||||
|
|
||||||
## Attribution
|
|
||||||
|
|
||||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
|
||||||
version 2.1, available at
|
|
||||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
|
||||||
|
|
||||||
Community Impact Guidelines were inspired by
|
|
||||||
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
|
||||||
|
|
||||||
For answers to common questions about this code of conduct, see the FAQ at
|
|
||||||
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
|
||||||
[https://www.contributor-covenant.org/translations][translations].
|
|
||||||
|
|
||||||
[homepage]: https://www.contributor-covenant.org
|
|
||||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
|
||||||
[Mozilla CoC]: https://github.com/mozilla/diversity
|
|
||||||
[FAQ]: https://www.contributor-covenant.org/faq
|
|
||||||
[translations]: https://www.contributor-covenant.org/translations
|
|
||||||
273
CONTRIBUTING.md
273
CONTRIBUTING.md
@@ -1,273 +0,0 @@
|
|||||||
# How to contribute to 🤗 LeRobot?
|
|
||||||
|
|
||||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
|
||||||
is thus not the only way to help the community. Answering questions, helping
|
|
||||||
others, reaching out and improving the documentations are immensely valuable to
|
|
||||||
the community.
|
|
||||||
|
|
||||||
It also helps us if you spread the word: reference the library from blog posts
|
|
||||||
on the awesome projects it made possible, shout out on Twitter when it has
|
|
||||||
helped you, or simply ⭐️ the repo to say "thank you".
|
|
||||||
|
|
||||||
Whichever way you choose to contribute, please be mindful to respect our
|
|
||||||
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
|
|
||||||
|
|
||||||
## You can contribute in so many ways!
|
|
||||||
|
|
||||||
Some of the ways you can contribute to 🤗 LeRobot:
|
|
||||||
* Fixing outstanding issues with the existing code.
|
|
||||||
* Implementing new models, datasets or simulation environments.
|
|
||||||
* Contributing to the examples or to the documentation.
|
|
||||||
* Submitting issues related to bugs or desired new features.
|
|
||||||
|
|
||||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
|
||||||
|
|
||||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
|
||||||
|
|
||||||
## Submitting a new issue or feature request
|
|
||||||
|
|
||||||
Do your best to follow these guidelines when submitting an issue or a feature
|
|
||||||
request. It will make it easier for us to come back to you quickly and with good
|
|
||||||
feedback.
|
|
||||||
|
|
||||||
### Did you find a bug?
|
|
||||||
|
|
||||||
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
|
|
||||||
the problems they encounter. So thank you for reporting an issue.
|
|
||||||
|
|
||||||
First, we would really appreciate it if you could **make sure the bug was not
|
|
||||||
already reported** (use the search bar on Github under Issues).
|
|
||||||
|
|
||||||
Did not find it? :( So we can act quickly on it, please follow these steps:
|
|
||||||
|
|
||||||
* Include your **OS type and version**, the versions of **Python** and **PyTorch**.
|
|
||||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
|
||||||
less than 30s.
|
|
||||||
* The full traceback if an exception is raised.
|
|
||||||
* Attach any other additional information, like screenshots, you think may help.
|
|
||||||
|
|
||||||
### Do you want a new feature?
|
|
||||||
|
|
||||||
A good feature request addresses the following points:
|
|
||||||
|
|
||||||
1. Motivation first:
|
|
||||||
* Is it related to a problem/frustration with the library? If so, please explain
|
|
||||||
why. Providing a code snippet that demonstrates the problem is best.
|
|
||||||
* Is it related to something you would need for a project? We'd love to hear
|
|
||||||
about it!
|
|
||||||
* Is it something you worked on and think could benefit the community?
|
|
||||||
Awesome! Tell us what problem it solved for you.
|
|
||||||
2. Write a *paragraph* describing the feature.
|
|
||||||
3. Provide a **code snippet** that demonstrates its future use.
|
|
||||||
4. In case this is related to a paper, please attach a link.
|
|
||||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
|
|
||||||
|
|
||||||
If your issue is well written we're already 80% of the way there by the time you
|
|
||||||
post it.
|
|
||||||
|
|
||||||
## Adding new policies, datasets or environments
|
|
||||||
|
|
||||||
Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
|
|
||||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
|
||||||
[xarm](https://github.com/huggingface/gym-xarm),
|
|
||||||
[pusht](https://github.com/huggingface/gym-pusht))
|
|
||||||
and follow the same api design.
|
|
||||||
|
|
||||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
|
||||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
|
||||||
|
|
||||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
|
||||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
|
||||||
|
|
||||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
|
||||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
|
||||||
- Set the required `name` class attribute.
|
|
||||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
|
||||||
|
|
||||||
## Submitting a pull request (PR)
|
|
||||||
|
|
||||||
Before writing code, we strongly advise you to search through the existing PRs or
|
|
||||||
issues to make sure that nobody is already working on the same thing. If you are
|
|
||||||
unsure, it is always a good idea to open an issue to get some feedback.
|
|
||||||
|
|
||||||
You will need basic `git` proficiency to be able to contribute to
|
|
||||||
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
|
|
||||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
|
||||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
|
||||||
|
|
||||||
Follow these steps to start contributing:
|
|
||||||
|
|
||||||
1. Fork the [repository](https://github.com/huggingface/lerobot) by
|
|
||||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
|
||||||
under your GitHub user account.
|
|
||||||
|
|
||||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
|
||||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
|
||||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone git@github.com:<your Github handle>/lerobot.git
|
|
||||||
cd lerobot
|
|
||||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
|
||||||
|
|
||||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git checkout main
|
|
||||||
git fetch upstream
|
|
||||||
git rebase upstream/main
|
|
||||||
```
|
|
||||||
|
|
||||||
Once your `main` branch is synchronized, create a new branch from it:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git checkout -b a-descriptive-name-for-my-changes
|
|
||||||
```
|
|
||||||
|
|
||||||
🚨 **Do not** work on the `main` branch.
|
|
||||||
|
|
||||||
4. Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
|
|
||||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
|
||||||
Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
|
||||||
Install the project with dev dependencies and all environments:
|
|
||||||
```bash
|
|
||||||
poetry install --sync --with dev --all-extras
|
|
||||||
```
|
|
||||||
This command should be run when pulling code with and updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
|
|
||||||
|
|
||||||
To selectively install environments (for example aloha and pusht) use:
|
|
||||||
```bash
|
|
||||||
poetry install --sync --with dev --extras "aloha pusht"
|
|
||||||
```
|
|
||||||
|
|
||||||
The equivalent of `pip install some-package`, would just be:
|
|
||||||
```bash
|
|
||||||
poetry add some-package
|
|
||||||
```
|
|
||||||
|
|
||||||
When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
|
||||||
```bash
|
|
||||||
poetry lock --no-update
|
|
||||||
```
|
|
||||||
|
|
||||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
|
||||||
```bash
|
|
||||||
# Add the new package to your main poetry env
|
|
||||||
poetry add some-package
|
|
||||||
# Add the same package to the CPU-only env dedicated to CI
|
|
||||||
conda create -y -n lerobot-ci python=3.10
|
|
||||||
conda activate lerobot-ci
|
|
||||||
cd .github/poetry/cpu
|
|
||||||
poetry add some-package
|
|
||||||
```
|
|
||||||
|
|
||||||
5. Develop the features on your branch.
|
|
||||||
|
|
||||||
As you work on the features, you should make sure that the test suite
|
|
||||||
passes. You should run the tests impacted by your changes like this (see
|
|
||||||
below an explanation regarding the environment variable):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/<TEST_TO_RUN>.py
|
|
||||||
```
|
|
||||||
|
|
||||||
6. Follow our style.
|
|
||||||
|
|
||||||
`lerobot` relies on `ruff` to format its source code
|
|
||||||
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
|
|
||||||
automatically as Git commit hooks.
|
|
||||||
|
|
||||||
Install `pre-commit` hooks:
|
|
||||||
```bash
|
|
||||||
pre-commit install
|
|
||||||
```
|
|
||||||
|
|
||||||
You can run these hooks whenever you need on staged files with:
|
|
||||||
```bash
|
|
||||||
pre-commit
|
|
||||||
```
|
|
||||||
|
|
||||||
Once you're happy with your changes, add changed files using `git add` and
|
|
||||||
make a commit with `git commit` to record your changes locally:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add modified_file.py
|
|
||||||
git commit
|
|
||||||
```
|
|
||||||
|
|
||||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
|
||||||
|
|
||||||
It is a good idea to sync your copy of the code with the original
|
|
||||||
repository regularly. This way you can quickly account for changes:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git fetch upstream
|
|
||||||
git rebase upstream/main
|
|
||||||
```
|
|
||||||
|
|
||||||
Push the changes to your account using:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git push -u origin a-descriptive-name-for-my-changes
|
|
||||||
```
|
|
||||||
|
|
||||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
|
||||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
|
||||||
to the project maintainers for review.
|
|
||||||
|
|
||||||
7. It's ok if maintainers ask you for changes. It happens to core contributors
|
|
||||||
too! So everyone can see the changes in the Pull request, work in your local
|
|
||||||
branch and push the changes to your fork. They will automatically appear in
|
|
||||||
the pull request.
|
|
||||||
|
|
||||||
|
|
||||||
### Checklist
|
|
||||||
|
|
||||||
1. The title of your pull request should be a summary of its contribution;
|
|
||||||
2. If your pull request addresses an issue, please mention the issue number in
|
|
||||||
the pull request description to make sure they are linked (and people
|
|
||||||
consulting the issue know you are working on it);
|
|
||||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
|
|
||||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
|
||||||
it from PRs ready to be merged;
|
|
||||||
4. Make sure existing tests pass;
|
|
||||||
<!-- 5. Add high-coverage tests. No quality testing = no merge.
|
|
||||||
|
|
||||||
See an example of a good PR here: https://github.com/huggingface/lerobot/pull/ -->
|
|
||||||
|
|
||||||
### Tests
|
|
||||||
|
|
||||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
|
|
||||||
|
|
||||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
|
||||||
|
|
||||||
On Mac:
|
|
||||||
```bash
|
|
||||||
brew install git-lfs
|
|
||||||
git lfs install
|
|
||||||
```
|
|
||||||
|
|
||||||
On Ubuntu:
|
|
||||||
```bash
|
|
||||||
sudo apt-get install git-lfs
|
|
||||||
git lfs install
|
|
||||||
```
|
|
||||||
|
|
||||||
Pull artifacts if they're not in [tests/data](tests/data)
|
|
||||||
```bash
|
|
||||||
git lfs pull
|
|
||||||
```
|
|
||||||
|
|
||||||
We use `pytest` in order to run the tests. From the root of the
|
|
||||||
repository, here's how to run tests with `pytest` for the library:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
DATA_DIR="tests/data" python -m pytest -sv ./tests
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
You can specify a smaller set of tests in order to test only the feature
|
|
||||||
you're working on.
|
|
||||||
148
README.md
148
README.md
@@ -17,7 +17,6 @@
|
|||||||
[](https://pypi.org/project/lerobot/)
|
[](https://pypi.org/project/lerobot/)
|
||||||
[](https://pypi.org/project/lerobot/)
|
[](https://pypi.org/project/lerobot/)
|
||||||
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
||||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
|
||||||
[](https://discord.gg/s3KuuzsPFb)
|
[](https://discord.gg/s3KuuzsPFb)
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
@@ -63,29 +62,21 @@
|
|||||||
|
|
||||||
Download our source code:
|
Download our source code:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/huggingface/lerobot.git && cd lerobot
|
git clone https://github.com/huggingface/lerobot.git
|
||||||
|
cd lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||||
```bash
|
```bash
|
||||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
conda create -y -n lerobot python=3.10
|
||||||
|
conda activate lerobot
|
||||||
```
|
```
|
||||||
|
|
||||||
Install 🤗 LeRobot:
|
Then, install 🤗 LeRobot:
|
||||||
```bash
|
```bash
|
||||||
python -m pip install .
|
python -m pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
|
||||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
|
||||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
|
||||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
|
||||||
|
|
||||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
|
||||||
```bash
|
|
||||||
python -m pip install ".[aloha, pusht]"
|
|
||||||
```
|
|
||||||
|
|
||||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||||
```bash
|
```bash
|
||||||
wandb login
|
wandb login
|
||||||
@@ -98,11 +89,11 @@ wandb login
|
|||||||
├── lerobot
|
├── lerobot
|
||||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
|
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
|
||||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||||
| ├── common # contains classes and utilities
|
| ├── common # contains classes and utilities
|
||||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
|
||||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
| | ├── envs # various sim environments: aloha, pusht, simxarm
|
||||||
| | └── policies # various policies: act, diffusion, tdmpc
|
| | └── policies # various policies: act, diffusion, tdmpc
|
||||||
| └── scripts # contains functions to execute via command line
|
| └── scripts # contains functions to execute via command line
|
||||||
| ├── visualize_dataset.py # load a dataset and render its demonstrations
|
| ├── visualize_dataset.py # load a dataset and render its demonstrations
|
||||||
@@ -118,19 +109,44 @@ wandb login
|
|||||||
|
|
||||||
### Visualize datasets
|
### Visualize datasets
|
||||||
|
|
||||||
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
|
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
|
||||||
|
```python
|
||||||
|
""" Copy pasted from `examples/1_visualize_dataset.py` """
|
||||||
|
import lerobot
|
||||||
|
from lerobot.common.datasets.aloha import AlohaDataset
|
||||||
|
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||||
|
from lerobot.scripts.visualize_dataset import render_dataset
|
||||||
|
|
||||||
|
print(lerobot.available_datasets)
|
||||||
|
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||||
|
|
||||||
|
# we use this sampler to sample 1 frame after the other
|
||||||
|
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||||
|
|
||||||
|
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler)
|
||||||
|
|
||||||
|
video_paths = render_dataset(
|
||||||
|
dataset,
|
||||||
|
out_dir="outputs/visualize_dataset/example",
|
||||||
|
max_num_samples=300,
|
||||||
|
fps=50,
|
||||||
|
)
|
||||||
|
print(video_paths)
|
||||||
|
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||||
|
```
|
||||||
|
|
||||||
Or you can achieve the same result by executing our script from the command line:
|
Or you can achieve the same result by executing our script from the command line:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/visualize_dataset.py \
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
env=pusht \
|
env=aloha \
|
||||||
|
task=sim_sim_transfer_cube_human \
|
||||||
hydra.run.dir=outputs/visualize_dataset/example
|
hydra.run.dir=outputs/visualize_dataset/example
|
||||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||||
```
|
```
|
||||||
|
|
||||||
### Evaluate a pretrained policy
|
### Evaluate a pretrained policy
|
||||||
|
|
||||||
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
|
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
|
||||||
|
|
||||||
Or you can achieve the same result by executing our script from the command line:
|
Or you can achieve the same result by executing our script from the command line:
|
||||||
```bash
|
```bash
|
||||||
@@ -153,30 +169,94 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
|
|||||||
|
|
||||||
### Train your own policy
|
### Train your own policy
|
||||||
|
|
||||||
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
|
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
|
||||||
|
|
||||||
In general, you can use our training script to easily train any policy on any environment:
|
In general, you can use our training script to easily train any policy on any environment:
|
||||||
```bash
|
```bash
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
env=aloha \
|
env=aloha \
|
||||||
task=sim_insertion \
|
task=sim_insertion \
|
||||||
repo_id=lerobot/aloha_sim_insertion_scripted \
|
dataset_id=aloha_sim_insertion_scripted \
|
||||||
policy=act \
|
policy=act \
|
||||||
hydra.run.dir=outputs/train/aloha_act
|
hydra.run.dir=outputs/train/aloha_act
|
||||||
```
|
```
|
||||||
|
|
||||||
After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
|
|
||||||
|
|
||||||
## Contribute
|
## Contribute
|
||||||
|
|
||||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
Feel free to open issues and PRs, and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||||
|
|
||||||
|
### TODO
|
||||||
|
|
||||||
|
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||||
|
|
||||||
|
### Follow our style
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# install if needed
|
||||||
|
pre-commit install
|
||||||
|
# apply style and linter checks before git commit
|
||||||
|
pre-commit
|
||||||
|
```
|
||||||
|
|
||||||
|
### Add dependencies
|
||||||
|
|
||||||
|
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
|
||||||
|
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||||
|
|
||||||
|
Install the project with:
|
||||||
|
```bash
|
||||||
|
poetry install
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, the equivalent of `pip install some-package`, would just be:
|
||||||
|
```bash
|
||||||
|
poetry add some-package
|
||||||
|
```
|
||||||
|
|
||||||
|
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
||||||
|
```bash
|
||||||
|
# Add the new package to your main poetry env
|
||||||
|
poetry add some-package
|
||||||
|
# Add the same package to the CPU-only env dedicated to CI
|
||||||
|
conda create -y -n lerobot-ci python=3.10
|
||||||
|
conda activate lerobot-ci
|
||||||
|
cd .github/poetry/cpu
|
||||||
|
poetry add some-package
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run tests locally
|
||||||
|
|
||||||
|
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||||
|
|
||||||
|
On Mac:
|
||||||
|
```bash
|
||||||
|
brew install git-lfs
|
||||||
|
git lfs install
|
||||||
|
```
|
||||||
|
|
||||||
|
On Ubuntu:
|
||||||
|
```bash
|
||||||
|
sudo apt-get install git-lfs
|
||||||
|
git lfs install
|
||||||
|
```
|
||||||
|
|
||||||
|
Pull artifacts if they're not in [tests/data](tests/data)
|
||||||
|
```bash
|
||||||
|
git lfs pull
|
||||||
|
```
|
||||||
|
|
||||||
|
When adding a new dataset, mock it with
|
||||||
|
```bash
|
||||||
|
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||||
|
```
|
||||||
|
|
||||||
|
Run tests
|
||||||
|
```bash
|
||||||
|
DATA_DIR="tests/data" pytest -sx tests
|
||||||
|
```
|
||||||
|
|
||||||
### Add a new dataset
|
### Add a new dataset
|
||||||
|
|
||||||
```python
|
|
||||||
# TODO(rcadene, AdilZouitine): rewrite this section
|
|
||||||
```
|
|
||||||
|
|
||||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||||
@@ -191,7 +271,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
|
|||||||
|
|
||||||
You will need to set the corresponding version as a default argument in your dataset class:
|
You will need to set the corresponding version as a default argument in your dataset class:
|
||||||
```python
|
```python
|
||||||
version: str | None = "v1.1",
|
version: str | None = "v1.0",
|
||||||
```
|
```
|
||||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||||
|
|
||||||
@@ -238,10 +318,6 @@ python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir
|
|||||||
|
|
||||||
### Add a pretrained policy
|
### Add a pretrained policy
|
||||||
|
|
||||||
```python
|
|
||||||
# TODO(rcadene, alexander-soare): rewrite this section
|
|
||||||
```
|
|
||||||
|
|
||||||
Once you have trained a policy you may upload it to the HuggingFace hub.
|
Once you have trained a policy you may upload it to the HuggingFace hub.
|
||||||
|
|
||||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||||
@@ -250,13 +326,15 @@ Secondly, assuming you have trained a policy, you need:
|
|||||||
|
|
||||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||||
|
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).
|
||||||
|
|
||||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||||
|
|
||||||
```
|
```
|
||||||
to_upload
|
to_upload
|
||||||
├── config.yaml
|
├── config.yaml
|
||||||
└── model.pt
|
├── model.pt
|
||||||
|
└── stats.pth
|
||||||
```
|
```
|
||||||
|
|
||||||
With the folder prepared, run the following with a desired revision ID.
|
With the folder prepared, run the following with a desired revision ID.
|
||||||
|
|||||||
@@ -1,550 +0,0 @@
|
|||||||
"""
|
|
||||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
|
||||||
useless dependencies when using datasets.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload(root, revision, dataset_id):
|
|
||||||
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
|
|
||||||
if "pusht" in dataset_id:
|
|
||||||
download_and_upload_pusht(root, revision, dataset_id)
|
|
||||||
elif "xarm" in dataset_id:
|
|
||||||
download_and_upload_xarm(root, revision, dataset_id)
|
|
||||||
elif "aloha" in dataset_id:
|
|
||||||
download_and_upload_aloha(root, revision, dataset_id)
|
|
||||||
else:
|
|
||||||
raise ValueError(dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
print(f"downloading from {url}")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
if response.status_code == 200:
|
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
|
||||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
|
||||||
|
|
||||||
zip_file = io.BytesIO()
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
zip_file.write(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
|
|
||||||
progress_bar.close()
|
|
||||||
|
|
||||||
zip_file.seek(0)
|
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(destination_folder)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def concatenate_episodes(ep_dicts):
|
|
||||||
data_dict = {}
|
|
||||||
|
|
||||||
keys = ep_dicts[0].keys()
|
|
||||||
for key in keys:
|
|
||||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
|
||||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
|
||||||
else:
|
|
||||||
if key not in data_dict:
|
|
||||||
data_dict[key] = []
|
|
||||||
for ep_dict in ep_dicts:
|
|
||||||
for x in ep_dict[key]:
|
|
||||||
data_dict[key].append(x)
|
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
|
||||||
# push to main to indicate latest version
|
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
|
||||||
|
|
||||||
# push to version branch
|
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
|
||||||
|
|
||||||
# create and store meta_data
|
|
||||||
meta_data_dir = root / dataset_id / "meta_data"
|
|
||||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
api = HfApi()
|
|
||||||
|
|
||||||
# info
|
|
||||||
info_path = meta_data_dir / "info.json"
|
|
||||||
with open(str(info_path), "w") as f:
|
|
||||||
json.dump(info, f, indent=4)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=info_path,
|
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=info_path,
|
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# stats
|
|
||||||
stats_path = meta_data_dir / "stats.safetensors"
|
|
||||||
save_file(flatten_dict(stats), stats_path)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=stats_path,
|
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=stats_path,
|
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# episode_data_index
|
|
||||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
|
||||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
|
||||||
save_file(episode_data_index, ep_data_idx_path)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# copy in tests folder, the first episode and the meta_data directory
|
|
||||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
|
||||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
|
||||||
f"tests/data/lerobot/{dataset_id}/train"
|
|
||||||
)
|
|
||||||
if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists():
|
|
||||||
shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data")
|
|
||||||
shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data")
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
|
||||||
try:
|
|
||||||
import pymunk
|
|
||||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
|
||||||
|
|
||||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
|
||||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
|
||||||
)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# as define in env
|
|
||||||
success_threshold = 0.95 # 95% coverage,
|
|
||||||
|
|
||||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
|
||||||
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
|
|
||||||
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / f"{dataset_id}_raw"
|
|
||||||
zarr_path = (raw_dir / pusht_zarr).resolve()
|
|
||||||
if not zarr_path.is_dir():
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
download_and_extract_zip(pusht_url, raw_dir)
|
|
||||||
|
|
||||||
# load
|
|
||||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
|
||||||
|
|
||||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
|
||||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
|
||||||
assert len(
|
|
||||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
|
||||||
), "Some data type dont have the same number of total frames."
|
|
||||||
|
|
||||||
# TODO: verify that goal pose is expected to be fixed
|
|
||||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
|
||||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
|
||||||
|
|
||||||
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
|
|
||||||
states = torch.from_numpy(dataset_dict["state"])
|
|
||||||
actions = torch.from_numpy(dataset_dict["action"])
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
|
||||||
id_to = dataset_dict.meta["episode_ends"][episode_id]
|
|
||||||
|
|
||||||
num_frames = id_to - id_from
|
|
||||||
|
|
||||||
assert (episode_ids[id_from:id_to] == episode_id).all()
|
|
||||||
|
|
||||||
image = imgs[id_from:id_to]
|
|
||||||
assert image.min() >= 0.0
|
|
||||||
assert image.max() <= 255.0
|
|
||||||
image = image.type(torch.uint8)
|
|
||||||
|
|
||||||
state = states[id_from:id_to]
|
|
||||||
agent_pos = state[:, :2]
|
|
||||||
block_pos = state[:, 2:4]
|
|
||||||
block_angle = state[:, 4]
|
|
||||||
|
|
||||||
reward = torch.zeros(num_frames)
|
|
||||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
for i in range(num_frames):
|
|
||||||
space = pymunk.Space()
|
|
||||||
space.gravity = 0, 0
|
|
||||||
space.damping = 0
|
|
||||||
|
|
||||||
# Add walls.
|
|
||||||
walls = [
|
|
||||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
|
||||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
|
||||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
|
||||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
|
||||||
]
|
|
||||||
space.add(*walls)
|
|
||||||
|
|
||||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
|
||||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
|
||||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
|
||||||
intersection_area = goal_geom.intersection(block_geom).area
|
|
||||||
goal_area = goal_geom.area
|
|
||||||
coverage = intersection_area / goal_area
|
|
||||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
|
||||||
success[i] = coverage > success_threshold
|
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
|
||||||
done[-1] = True
|
|
||||||
|
|
||||||
ep_dict = {
|
|
||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
|
||||||
"observation.state": agent_pos,
|
|
||||||
"action": actions[id_from:id_to],
|
|
||||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.image": image[1:],
|
|
||||||
# "next.observation.state": agent_pos[1:],
|
|
||||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
|
||||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
|
||||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
|
||||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
|
||||||
}
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from += num_frames
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.image": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
"next.success": Value(dtype="bool", id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / "xarm_datasets_raw"
|
|
||||||
if not raw_dir.exists():
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
import gdown
|
|
||||||
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
|
||||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
|
||||||
zip_path = raw_dir / "data.zip"
|
|
||||||
gdown.download(url, str(zip_path), quiet=False)
|
|
||||||
print("Extracting...")
|
|
||||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
|
||||||
for member in zip_f.namelist():
|
|
||||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
|
||||||
print(member)
|
|
||||||
zip_f.extract(member=member)
|
|
||||||
zip_path.unlink()
|
|
||||||
|
|
||||||
dataset_path = root / f"{dataset_id}" / "buffer.pkl"
|
|
||||||
print(f"Using offline dataset '{dataset_path}'")
|
|
||||||
with open(dataset_path, "rb") as f:
|
|
||||||
dataset_dict = pickle.load(f)
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
id_to = 0
|
|
||||||
episode_id = 0
|
|
||||||
total_frames = dataset_dict["actions"].shape[0]
|
|
||||||
for i in tqdm.tqdm(range(total_frames)):
|
|
||||||
id_to += 1
|
|
||||||
|
|
||||||
if not dataset_dict["dones"][i]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_frames = id_to - id_from
|
|
||||||
|
|
||||||
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
|
|
||||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
|
||||||
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
|
|
||||||
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
|
|
||||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
|
||||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
|
||||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
|
|
||||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
|
|
||||||
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
|
|
||||||
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
|
|
||||||
|
|
||||||
ep_dict = {
|
|
||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
|
||||||
"observation.state": state,
|
|
||||||
"action": action,
|
|
||||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.image": next_image,
|
|
||||||
# "next.observation.state": next_state,
|
|
||||||
"next.reward": next_reward,
|
|
||||||
"next.done": next_done,
|
|
||||||
}
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from = id_to
|
|
||||||
episode_id += 1
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.image": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
#'next.success': Value(dtype='bool', id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|
||||||
folder_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
|
||||||
}
|
|
||||||
|
|
||||||
ep48_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
|
||||||
}
|
|
||||||
|
|
||||||
ep49_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
|
||||||
}
|
|
||||||
|
|
||||||
num_episodes = {
|
|
||||||
"aloha_sim_insertion_human": 50,
|
|
||||||
"aloha_sim_insertion_scripted": 50,
|
|
||||||
"aloha_sim_transfer_cube_human": 50,
|
|
||||||
"aloha_sim_transfer_cube_scripted": 50,
|
|
||||||
}
|
|
||||||
|
|
||||||
episode_len = {
|
|
||||||
"aloha_sim_insertion_human": 500,
|
|
||||||
"aloha_sim_insertion_scripted": 400,
|
|
||||||
"aloha_sim_transfer_cube_human": 400,
|
|
||||||
"aloha_sim_transfer_cube_scripted": 400,
|
|
||||||
}
|
|
||||||
|
|
||||||
cameras = {
|
|
||||||
"aloha_sim_insertion_human": ["top"],
|
|
||||||
"aloha_sim_insertion_scripted": ["top"],
|
|
||||||
"aloha_sim_transfer_cube_human": ["top"],
|
|
||||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
|
||||||
}
|
|
||||||
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / f"{dataset_id}_raw"
|
|
||||||
if not raw_dir.is_dir():
|
|
||||||
import gdown
|
|
||||||
|
|
||||||
assert dataset_id in folder_urls
|
|
||||||
assert dataset_id in ep48_urls
|
|
||||||
assert dataset_id in ep49_urls
|
|
||||||
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
|
|
||||||
|
|
||||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
|
||||||
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
|
||||||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
|
||||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
|
||||||
with h5py.File(ep_path, "r") as ep:
|
|
||||||
num_frames = ep["/action"].shape[0]
|
|
||||||
assert episode_len[dataset_id] == num_frames
|
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
done[-1] = True
|
|
||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
|
||||||
|
|
||||||
ep_dict = {}
|
|
||||||
|
|
||||||
for cam in cameras[dataset_id]:
|
|
||||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
|
||||||
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
|
||||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
|
||||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
|
||||||
|
|
||||||
ep_dict.update(
|
|
||||||
{
|
|
||||||
"observation.state": state,
|
|
||||||
"action": action,
|
|
||||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.state": state,
|
|
||||||
# TODO(rcadene): compute reward and success
|
|
||||||
# "next.reward": reward,
|
|
||||||
"next.done": done,
|
|
||||||
# "next.success": success,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(ep_id, int)
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from += num_frames
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.images.top": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
#'next.reward': Value(dtype='float32', id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
#'next.success': Value(dtype='bool', id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
root = "data"
|
|
||||||
revision = "v1.1"
|
|
||||||
|
|
||||||
dataset_ids = [
|
|
||||||
"pusht",
|
|
||||||
"xarm_lift_medium",
|
|
||||||
"xarm_lift_medium_replay",
|
|
||||||
"xarm_push_medium",
|
|
||||||
"xarm_push_medium_replay",
|
|
||||||
"aloha_sim_insertion_human",
|
|
||||||
"aloha_sim_insertion_scripted",
|
|
||||||
"aloha_sim_transfer_cube_human",
|
|
||||||
"aloha_sim_transfer_cube_scripted",
|
|
||||||
]
|
|
||||||
for dataset_id in dataset_ids:
|
|
||||||
download_and_upload(root, revision, dataset_id)
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""
|
|
||||||
This script demonstrates the visualization of various robotic datasets from Hugging Face hub.
|
|
||||||
It covers the steps from loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
|
|
||||||
Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
|
|
||||||
It is compatible with pytorch, jax, numpy, etc.
|
|
||||||
|
|
||||||
As an example, this script saves frames of episode number 5 of the PushT dataset to a mp4 video and saves the result here:
|
|
||||||
`outputs/examples/1_visualize_hugging_face_datasets/episode_5.mp4`
|
|
||||||
|
|
||||||
This script supports several Hugging Face datasets, among which:
|
|
||||||
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
|
|
||||||
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
|
|
||||||
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
|
|
||||||
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
|
|
||||||
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
|
|
||||||
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
|
||||||
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
|
||||||
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
|
||||||
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
|
||||||
|
|
||||||
To try a different Hugging Face dataset, you can replace this line:
|
|
||||||
```python
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
|
||||||
```
|
|
||||||
by one of these:
|
|
||||||
```python
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
# TODO(rcadene): remove this example file of using hf_dataset
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import imageio
|
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
|
||||||
|
|
||||||
# download/load hugging face dataset in pyarrow format
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
|
|
||||||
|
|
||||||
# display name of dataset and its features
|
|
||||||
# TODO(rcadene): update to make the print pretty
|
|
||||||
print(f"{hf_dataset=}")
|
|
||||||
print(f"{hf_dataset.features=}")
|
|
||||||
|
|
||||||
# display useful statistics about frames and episodes, which are sequences of frames from the same video
|
|
||||||
print(f"number of frames: {len(hf_dataset)=}")
|
|
||||||
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
|
|
||||||
print(
|
|
||||||
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# select the frames belonging to episode number 5
|
|
||||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
|
||||||
|
|
||||||
# load all frames of episode 5 in RAM in PIL format
|
|
||||||
frames = hf_dataset["observation.image"]
|
|
||||||
|
|
||||||
# save episode frames to a mp4 video
|
|
||||||
Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
|
|
||||||
imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)
|
|
||||||
24
examples/1_visualize_dataset.py
Normal file
24
examples/1_visualize_dataset.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||||
|
|
||||||
|
import lerobot
|
||||||
|
from lerobot.common.datasets.aloha import AlohaDataset
|
||||||
|
from lerobot.scripts.visualize_dataset import render_dataset
|
||||||
|
|
||||||
|
print(lerobot.available_datasets)
|
||||||
|
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||||
|
|
||||||
|
# we use this sampler to sample 1 frame after the other
|
||||||
|
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||||
|
|
||||||
|
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
|
||||||
|
|
||||||
|
video_paths = render_dataset(
|
||||||
|
dataset,
|
||||||
|
out_dir="outputs/visualize_dataset/example",
|
||||||
|
max_num_samples=300,
|
||||||
|
fps=50,
|
||||||
|
)
|
||||||
|
print(video_paths)
|
||||||
|
# ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||||
@@ -7,11 +7,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
from lerobot.scripts.eval import eval
|
from lerobot.scripts.eval import eval
|
||||||
|
|
||||||
# Get a pretrained policy from the hub.
|
# Get a pretrained policy from the hub.
|
||||||
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
|
|
||||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||||
folder = Path(snapshot_download(hub_id))
|
folder = Path(snapshot_download(hub_id))
|
||||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||||
@@ -19,6 +18,7 @@ folder = Path(snapshot_download(hub_id))
|
|||||||
|
|
||||||
config_path = folder / "config.yaml"
|
config_path = folder / "config.yaml"
|
||||||
weights_path = folder / "model.pt"
|
weights_path = folder / "model.pt"
|
||||||
|
stats_path = folder / "stats.pth" # normalization stats
|
||||||
|
|
||||||
# Override some config parameters to do with evaluation.
|
# Override some config parameters to do with evaluation.
|
||||||
overrides = [
|
overrides = [
|
||||||
@@ -35,4 +35,5 @@ cfg = init_hydra_config(config_path, overrides)
|
|||||||
eval(
|
eval(
|
||||||
cfg,
|
cfg,
|
||||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||||
|
stats_path=stats_path,
|
||||||
)
|
)
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""
|
|
||||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
|
||||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
|
||||||
|
|
||||||
Features included in this script:
|
|
||||||
- Loading a dataset and accessing its properties.
|
|
||||||
- Filtering data by episode number.
|
|
||||||
- Converting tensor data for visualization.
|
|
||||||
- Saving video files from dataset frames.
|
|
||||||
- Using advanced dataset features like timestamp-based frame selection.
|
|
||||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
|
||||||
|
|
||||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import imageio
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import lerobot
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
|
|
||||||
print("List of available datasets", lerobot.available_datasets)
|
|
||||||
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
|
|
||||||
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
|
|
||||||
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
|
|
||||||
|
|
||||||
repo_id = "lerobot/pusht"
|
|
||||||
|
|
||||||
# You can easily load a dataset from a Hugging Face repositery
|
|
||||||
dataset = LeRobotDataset(repo_id)
|
|
||||||
|
|
||||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
|
||||||
# TODO(rcadene): update to make the print pretty
|
|
||||||
print(f"{dataset=}")
|
|
||||||
print(f"{dataset.hf_dataset=}")
|
|
||||||
|
|
||||||
# and provides additional utilities for robotics and compatibility with pytorch
|
|
||||||
print(f"number of samples/frames: {dataset.num_samples=}")
|
|
||||||
print(f"number of episodes: {dataset.num_episodes=}")
|
|
||||||
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
|
||||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
|
||||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
|
||||||
|
|
||||||
# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset.
|
|
||||||
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
|
||||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
|
||||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
|
||||||
|
|
||||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
|
||||||
frames = [sample["observation.image"] for sample in dataset]
|
|
||||||
|
|
||||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
|
||||||
# to view them, we convert to uint8 range [0,255]
|
|
||||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
|
||||||
# and to channel last (h,w,c)
|
|
||||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
|
||||||
|
|
||||||
# and finally save them to a mp4 video
|
|
||||||
Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
|
||||||
imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
|
|
||||||
|
|
||||||
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
|
|
||||||
# using timestamps differences with the current loaded frame. For instance:
|
|
||||||
delta_timestamps = {
|
|
||||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
|
||||||
"observation.image": [-1, -0.5, -0.20, 0],
|
|
||||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
|
|
||||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
|
|
||||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
|
||||||
"action": [t / dataset.fps for t in range(64)],
|
|
||||||
}
|
|
||||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
|
||||||
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
|
||||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
|
||||||
print(f"{dataset[0]['action'].shape=}") # (64,c)
|
|
||||||
|
|
||||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
|
|
||||||
# because they are just PyTorch datasets.
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=0,
|
|
||||||
batch_size=32,
|
|
||||||
shuffle=True,
|
|
||||||
)
|
|
||||||
for batch in dataloader:
|
|
||||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
|
||||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
|
||||||
print(f"{batch['action'].shape=}") # (32,64,c)
|
|
||||||
break
|
|
||||||
55
examples/3_train_policy.py
Normal file
55
examples/3_train_policy.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||||
|
|
||||||
|
Once you have trained a model with this script, you can try to evaluate it on
|
||||||
|
examples/2_evaluate_pretrained_policy.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import trange
|
||||||
|
|
||||||
|
from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
|
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||||
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
|
||||||
|
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||||
|
os.makedirs(output_directory, exist_ok=True)
|
||||||
|
|
||||||
|
overrides = [
|
||||||
|
"env=pusht",
|
||||||
|
"policy=diffusion",
|
||||||
|
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||||
|
"offline_steps=5000",
|
||||||
|
"log_freq=250",
|
||||||
|
"device=cuda",
|
||||||
|
]
|
||||||
|
|
||||||
|
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
|
||||||
|
|
||||||
|
policy = DiffusionPolicy(
|
||||||
|
cfg=cfg.policy,
|
||||||
|
cfg_device=cfg.device,
|
||||||
|
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||||
|
cfg_rgb_model=cfg.rgb_model,
|
||||||
|
cfg_obs_encoder=cfg.obs_encoder,
|
||||||
|
cfg_optimizer=cfg.optimizer,
|
||||||
|
cfg_ema=cfg.ema,
|
||||||
|
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||||
|
**cfg.policy,
|
||||||
|
)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
||||||
|
for offline_step in trange(cfg.offline_steps):
|
||||||
|
train_info = policy.update(offline_buffer, offline_step)
|
||||||
|
if offline_step % cfg.log_freq == 0:
|
||||||
|
print(train_info)
|
||||||
|
|
||||||
|
# Save the policy, configuration, and normalization stats for later use.
|
||||||
|
policy.save_pretrained(output_directory / "model.pt")
|
||||||
|
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||||
|
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
|
||||||
|
|
||||||
Once you have trained a model with this script, you can try to evaluate it on
|
|
||||||
examples/2_evaluate_pretrained_policy.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
|
||||||
|
|
||||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
|
||||||
os.makedirs(output_directory, exist_ok=True)
|
|
||||||
|
|
||||||
# Number of offline training steps (we'll only do offline training for this example.
|
|
||||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
|
||||||
training_steps = 5000
|
|
||||||
device = torch.device("cuda")
|
|
||||||
log_freq = 250
|
|
||||||
|
|
||||||
# Set up the dataset.
|
|
||||||
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
|
|
||||||
dataset = make_dataset(hydra_cfg)
|
|
||||||
|
|
||||||
# Set up the the policy.
|
|
||||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
|
||||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
|
||||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
|
||||||
cfg = DiffusionConfig()
|
|
||||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
|
||||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
|
||||||
policy.train()
|
|
||||||
policy.to(device)
|
|
||||||
|
|
||||||
# Create dataloader for offline training.
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=cfg.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
pin_memory=device != torch.device("cpu"),
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run training loop.
|
|
||||||
step = 0
|
|
||||||
done = False
|
|
||||||
while not done:
|
|
||||||
for batch in dataloader:
|
|
||||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
|
||||||
info = policy.update(batch)
|
|
||||||
if step % log_freq == 0:
|
|
||||||
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
|
|
||||||
step += 1
|
|
||||||
if step >= training_steps:
|
|
||||||
done = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Save the policy and configuration for later use.
|
|
||||||
policy.save(output_directory / "model.pt")
|
|
||||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
|
||||||
@@ -7,72 +7,53 @@ Example:
|
|||||||
import lerobot
|
import lerobot
|
||||||
print(lerobot.available_envs)
|
print(lerobot.available_envs)
|
||||||
print(lerobot.available_tasks_per_env)
|
print(lerobot.available_tasks_per_env)
|
||||||
print(lerobot.available_datasets)
|
|
||||||
print(lerobot.available_datasets_per_env)
|
print(lerobot.available_datasets_per_env)
|
||||||
|
print(lerobot.available_datasets)
|
||||||
print(lerobot.available_policies)
|
print(lerobot.available_policies)
|
||||||
print(lerobot.available_policies_per_env)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
Note:
|
||||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||||
|
1. set the required class attributes:
|
||||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||||
|
- for classes inheriting from `AbstractPolicy`: `name`
|
||||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
3. update variables in `tests/test_available.py` by importing your new class
|
||||||
- Set the required `name` class attribute.
|
|
||||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from lerobot.__version__ import __version__ # noqa: F401
|
from lerobot.__version__ import __version__ # noqa: F401
|
||||||
|
|
||||||
|
available_envs = [
|
||||||
|
"aloha",
|
||||||
|
"pusht",
|
||||||
|
"simxarm",
|
||||||
|
]
|
||||||
|
|
||||||
available_tasks_per_env = {
|
available_tasks_per_env = {
|
||||||
"aloha": [
|
"aloha": [
|
||||||
"AlohaInsertion-v0",
|
"sim_insertion",
|
||||||
"AlohaTransferCube-v0",
|
"sim_transfer_cube",
|
||||||
],
|
],
|
||||||
"pusht": ["PushT-v0"],
|
"pusht": ["pusht"],
|
||||||
"xarm": ["XarmLift-v0"],
|
"simxarm": ["lift"],
|
||||||
}
|
}
|
||||||
available_envs = list(available_tasks_per_env.keys())
|
|
||||||
|
|
||||||
available_datasets_per_env = {
|
available_datasets_per_env = {
|
||||||
"aloha": [
|
"aloha": [
|
||||||
"lerobot/aloha_sim_insertion_human",
|
"aloha_sim_insertion_human",
|
||||||
"lerobot/aloha_sim_insertion_scripted",
|
"aloha_sim_insertion_scripted",
|
||||||
"lerobot/aloha_sim_transfer_cube_human",
|
"aloha_sim_transfer_cube_human",
|
||||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
"aloha_sim_transfer_cube_scripted",
|
||||||
],
|
|
||||||
"pusht": ["lerobot/pusht"],
|
|
||||||
"xarm": [
|
|
||||||
"lerobot/xarm_lift_medium",
|
|
||||||
"lerobot/xarm_lift_medium_replay",
|
|
||||||
"lerobot/xarm_push_medium",
|
|
||||||
"lerobot/xarm_push_medium_replay",
|
|
||||||
],
|
],
|
||||||
|
"pusht": ["pusht"],
|
||||||
|
"simxarm": ["xarm_lift_medium"],
|
||||||
}
|
}
|
||||||
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
|
|
||||||
|
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
|
||||||
|
|
||||||
available_policies = [
|
available_policies = [
|
||||||
"act",
|
"act",
|
||||||
"diffusion",
|
"diffusion",
|
||||||
"tdmpc",
|
"tdmpc",
|
||||||
]
|
]
|
||||||
|
|
||||||
available_policies_per_env = {
|
|
||||||
"aloha": ["act"],
|
|
||||||
"pusht": ["diffusion"],
|
|
||||||
"xarm": ["tdmpc"],
|
|
||||||
}
|
|
||||||
|
|
||||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
|
||||||
env_dataset_pairs = [
|
|
||||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
|
||||||
]
|
|
||||||
env_dataset_policy_triplets = [
|
|
||||||
(env, dataset, policy)
|
|
||||||
for env, datasets in available_datasets_per_env.items()
|
|
||||||
for dataset in datasets
|
|
||||||
for policy in available_policies_per_env[env]
|
|
||||||
]
|
|
||||||
|
|||||||
207
lerobot/common/datasets/abstract.py
Normal file
207
lerobot/common/datasets/abstract.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import torchrl
|
||||||
|
import tqdm
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||||
|
from torchrl.data.replay_buffers.samplers import Sampler
|
||||||
|
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||||
|
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||||
|
from torchrl.envs.transforms.transforms import Compose
|
||||||
|
|
||||||
|
HF_USER = "lerobot"
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractDataset(TensorDictReplayBuffer):
|
||||||
|
"""
|
||||||
|
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
||||||
|
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
||||||
|
These implementations can vary based on the source of the data, the environment the data pertains to,
|
||||||
|
or the specific kind of data manipulation applied.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
||||||
|
functionality for storing and retrieving `TensorDict`-like data.
|
||||||
|
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
||||||
|
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
||||||
|
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
||||||
|
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||||
|
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||||
|
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||||
|
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||||
|
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||||
|
1. set the required class attributes:
|
||||||
|
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||||
|
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||||
|
- for classes inheriting from `AbstractPolicy`: `name`
|
||||||
|
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||||
|
3. update variables in `tests/test_available.py` by importing your new class
|
||||||
|
"""
|
||||||
|
|
||||||
|
available_datasets: list[str] | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
version: str | None = None,
|
||||||
|
batch_size: int | None = None,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
root: Path | None = None,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
prefetch: int = None,
|
||||||
|
sampler: Sampler | None = None,
|
||||||
|
collate_fn: Callable | None = None,
|
||||||
|
writer: Writer | None = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
self.available_datasets is not None
|
||||||
|
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
||||||
|
assert (
|
||||||
|
dataset_id in self.available_datasets
|
||||||
|
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
||||||
|
|
||||||
|
self.dataset_id = dataset_id
|
||||||
|
self.version = version
|
||||||
|
self.shuffle = shuffle
|
||||||
|
self.root = root if root is None else Path(root)
|
||||||
|
|
||||||
|
if self.root is not None and self.version is not None:
|
||||||
|
logging.warning(
|
||||||
|
f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = self._download_or_load_dataset()
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
storage=storage,
|
||||||
|
sampler=sampler,
|
||||||
|
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||||
|
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
|
batch_size=batch_size,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats_patterns(self) -> dict:
|
||||||
|
return {
|
||||||
|
("observation", "state"): "b c -> c",
|
||||||
|
("observation", "image"): "b c h w -> c 1 1",
|
||||||
|
("action",): "b c -> c",
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_keys(self) -> list:
|
||||||
|
return [("observation", "image")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cameras(self) -> int:
|
||||||
|
return len(self.image_keys)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_samples(self) -> int:
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self) -> int:
|
||||||
|
return len(self._storage._storage["episode"].unique())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def transform(self):
|
||||||
|
return self._transform
|
||||||
|
|
||||||
|
def set_transform(self, transform):
|
||||||
|
if not isinstance(transform, Compose):
|
||||||
|
# required since torchrl calls `len(self._transform)` downstream
|
||||||
|
if isinstance(transform, list):
|
||||||
|
self._transform = Compose(*transform)
|
||||||
|
else:
|
||||||
|
self._transform = Compose(transform)
|
||||||
|
else:
|
||||||
|
self._transform = transform
|
||||||
|
|
||||||
|
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||||
|
stats_path = self.data_dir / "stats.pth"
|
||||||
|
if stats_path.exists():
|
||||||
|
stats = torch.load(stats_path)
|
||||||
|
else:
|
||||||
|
logging.info(f"compute_stats and save to {stats_path}")
|
||||||
|
stats = self._compute_stats(num_batch, batch_size)
|
||||||
|
torch.save(stats, stats_path)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def _download_or_load_dataset(self) -> torch.StorageBase:
|
||||||
|
if self.root is None:
|
||||||
|
self.data_dir = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_dir = self.root / self.dataset_id
|
||||||
|
return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
|
||||||
|
|
||||||
|
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||||
|
rb = TensorDictReplayBuffer(
|
||||||
|
storage=self._storage,
|
||||||
|
batch_size=batch_size,
|
||||||
|
prefetch=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mean, std, max, min = {}, {}, {}, {}
|
||||||
|
|
||||||
|
# compute mean, min, max
|
||||||
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
|
batch = rb.sample()
|
||||||
|
for key, pattern in self.stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
if key not in mean:
|
||||||
|
# first batch initialize mean, min, max
|
||||||
|
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||||
|
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||||
|
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||||
|
else:
|
||||||
|
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||||
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
batch = rb.sample()
|
||||||
|
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
mean[key] /= num_batch
|
||||||
|
|
||||||
|
# compute std, min, max
|
||||||
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
|
batch = rb.sample()
|
||||||
|
for key, pattern in self.stats_patterns.items():
|
||||||
|
batch[key] = batch[key].float()
|
||||||
|
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||||
|
if key not in std:
|
||||||
|
# first batch initialize std
|
||||||
|
std[key] = (batch_mean - mean[key]) ** 2
|
||||||
|
else:
|
||||||
|
std[key] += (batch_mean - mean[key]) ** 2
|
||||||
|
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||||
|
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||||
|
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
std[key] = torch.sqrt(std[key] / num_batch)
|
||||||
|
|
||||||
|
stats = TensorDict({}, batch_size=[])
|
||||||
|
for key in self.stats_patterns:
|
||||||
|
stats[(*key, "mean")] = mean[key]
|
||||||
|
stats[(*key, "std")] = std[key]
|
||||||
|
stats[(*key, "max")] = max[key]
|
||||||
|
stats[(*key, "min")] = min[key]
|
||||||
|
|
||||||
|
if key[0] == "observation":
|
||||||
|
# use same stats for the next observations
|
||||||
|
stats[("next", *key)] = stats[key]
|
||||||
|
return stats
|
||||||
185
lerobot/common/datasets/aloha.py
Normal file
185
lerobot/common/datasets/aloha.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import gdown
|
||||||
|
import h5py
|
||||||
|
import torch
|
||||||
|
import torchrl
|
||||||
|
import tqdm
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.replay_buffers.samplers import Sampler
|
||||||
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
|
from lerobot.common.datasets.abstract import AbstractDataset
|
||||||
|
|
||||||
|
DATASET_IDS = [
|
||||||
|
"aloha_sim_insertion_human",
|
||||||
|
"aloha_sim_insertion_scripted",
|
||||||
|
"aloha_sim_transfer_cube_human",
|
||||||
|
"aloha_sim_transfer_cube_scripted",
|
||||||
|
]
|
||||||
|
|
||||||
|
FOLDER_URLS = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||||
|
}
|
||||||
|
|
||||||
|
EP48_URLS = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||||
|
}
|
||||||
|
|
||||||
|
EP49_URLS = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||||
|
}
|
||||||
|
|
||||||
|
NUM_EPISODES = {
|
||||||
|
"aloha_sim_insertion_human": 50,
|
||||||
|
"aloha_sim_insertion_scripted": 50,
|
||||||
|
"aloha_sim_transfer_cube_human": 50,
|
||||||
|
"aloha_sim_transfer_cube_scripted": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
EPISODE_LEN = {
|
||||||
|
"aloha_sim_insertion_human": 500,
|
||||||
|
"aloha_sim_insertion_scripted": 400,
|
||||||
|
"aloha_sim_transfer_cube_human": 400,
|
||||||
|
"aloha_sim_transfer_cube_scripted": 400,
|
||||||
|
}
|
||||||
|
|
||||||
|
CAMERAS = {
|
||||||
|
"aloha_sim_insertion_human": ["top"],
|
||||||
|
"aloha_sim_insertion_scripted": ["top"],
|
||||||
|
"aloha_sim_transfer_cube_human": ["top"],
|
||||||
|
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download(data_dir, dataset_id):
|
||||||
|
assert dataset_id in DATASET_IDS
|
||||||
|
assert dataset_id in FOLDER_URLS
|
||||||
|
assert dataset_id in EP48_URLS
|
||||||
|
assert dataset_id in EP49_URLS
|
||||||
|
|
||||||
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||||
|
|
||||||
|
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||||
|
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||||
|
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AlohaDataset(AbstractDataset):
|
||||||
|
available_datasets = DATASET_IDS
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
version: str | None = "v1.2",
|
||||||
|
batch_size: int | None = None,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
root: Path | None = None,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
prefetch: int = None,
|
||||||
|
sampler: Sampler | None = None,
|
||||||
|
collate_fn: Callable | None = None,
|
||||||
|
writer: Writer | None = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
dataset_id,
|
||||||
|
version,
|
||||||
|
batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
root=root,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
|
sampler=sampler,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
writer=writer,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats_patterns(self) -> dict:
|
||||||
|
d = {
|
||||||
|
("observation", "state"): "b c -> c",
|
||||||
|
("action",): "b c -> c",
|
||||||
|
}
|
||||||
|
for cam in CAMERAS[self.dataset_id]:
|
||||||
|
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
||||||
|
return d
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_keys(self) -> list:
|
||||||
|
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||||
|
|
||||||
|
def _download_and_preproc_obsolete(self):
|
||||||
|
assert self.root is not None
|
||||||
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||||
|
if not raw_dir.is_dir():
|
||||||
|
download(raw_dir, self.dataset_id)
|
||||||
|
|
||||||
|
total_num_frames = 0
|
||||||
|
logging.info("Compute total number of frames to initialize offline buffer")
|
||||||
|
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||||
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||||
|
with h5py.File(ep_path, "r") as ep:
|
||||||
|
total_num_frames += ep["/action"].shape[0] - 1
|
||||||
|
logging.info(f"{total_num_frames=}")
|
||||||
|
|
||||||
|
logging.info("Initialize and feed offline buffer")
|
||||||
|
idxtd = 0
|
||||||
|
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||||
|
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||||
|
with h5py.File(ep_path, "r") as ep:
|
||||||
|
ep_num_frames = ep["/action"].shape[0]
|
||||||
|
|
||||||
|
# last step of demonstration is considered done
|
||||||
|
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||||
|
done[-1] = True
|
||||||
|
|
||||||
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
|
||||||
|
ep_td = TensorDict(
|
||||||
|
{
|
||||||
|
("observation", "state"): state[:-1],
|
||||||
|
"action": action[:-1],
|
||||||
|
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||||
|
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||||
|
("next", "observation", "state"): state[1:],
|
||||||
|
# TODO: compute reward and success
|
||||||
|
# ("next", "reward"): reward[1:],
|
||||||
|
("next", "done"): done[1:],
|
||||||
|
# ("next", "success"): success[1:],
|
||||||
|
},
|
||||||
|
batch_size=ep_num_frames - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
for cam in CAMERAS[self.dataset_id]:
|
||||||
|
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||||
|
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||||
|
ep_td["observation", "image", cam] = image[:-1]
|
||||||
|
ep_td["next", "observation", "image", cam] = image[1:]
|
||||||
|
|
||||||
|
if ep_id == 0:
|
||||||
|
# hack to initialize tensordict data structure to store episodes
|
||||||
|
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||||
|
|
||||||
|
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||||
|
idxtd = idxtd + len(ep_td)
|
||||||
|
|
||||||
|
return TensorStorage(td_data.lock_())
|
||||||
@@ -3,42 +3,134 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||||
|
|
||||||
|
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||||
|
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||||
|
# to load a subset of our datasets for faster continuous integration.
|
||||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||||
|
|
||||||
|
|
||||||
def make_dataset(
|
def make_offline_buffer(
|
||||||
cfg,
|
cfg,
|
||||||
split="train",
|
overwrite_sampler=None,
|
||||||
|
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||||
|
normalize=True,
|
||||||
|
overwrite_batch_size=None,
|
||||||
|
overwrite_prefetch=None,
|
||||||
|
stats_path=None,
|
||||||
):
|
):
|
||||||
if cfg.env.name not in cfg.dataset.repo_id:
|
if cfg.policy.balanced_sampling:
|
||||||
logging.warning(
|
assert cfg.online_steps > 0
|
||||||
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
|
batch_size = None
|
||||||
)
|
pin_memory = False
|
||||||
|
prefetch = None
|
||||||
|
else:
|
||||||
|
assert cfg.online_steps == 0
|
||||||
|
num_slices = cfg.policy.batch_size
|
||||||
|
batch_size = cfg.policy.horizon * num_slices
|
||||||
|
pin_memory = cfg.device == "cuda"
|
||||||
|
prefetch = cfg.prefetch
|
||||||
|
|
||||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
if overwrite_batch_size is not None:
|
||||||
if delta_timestamps is not None:
|
batch_size = overwrite_batch_size
|
||||||
for key in delta_timestamps:
|
|
||||||
if isinstance(delta_timestamps[key], str):
|
|
||||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
|
||||||
|
|
||||||
# TODO(rcadene): add data augmentations
|
if overwrite_prefetch is not None:
|
||||||
|
prefetch = overwrite_prefetch
|
||||||
|
|
||||||
dataset = LeRobotDataset(
|
if overwrite_sampler is None:
|
||||||
cfg.dataset.repo_id,
|
# TODO(rcadene): move batch_size outside
|
||||||
split=split,
|
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||||
|
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||||
|
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||||
|
|
||||||
|
if cfg.offline_prioritized_sampler:
|
||||||
|
logging.info("use prioritized sampler for offline dataset")
|
||||||
|
sampler = PrioritizedSliceSampler(
|
||||||
|
max_capacity=100_000,
|
||||||
|
alpha=cfg.policy.per_alpha,
|
||||||
|
beta=cfg.policy.per_beta,
|
||||||
|
num_slices=num_traj_per_batch,
|
||||||
|
strict_length=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("use simple sampler for offline dataset")
|
||||||
|
sampler = SliceSampler(
|
||||||
|
num_slices=num_traj_per_batch,
|
||||||
|
strict_length=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sampler = overwrite_sampler
|
||||||
|
|
||||||
|
if cfg.env.name == "simxarm":
|
||||||
|
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||||
|
|
||||||
|
clsfunc = SimxarmDataset
|
||||||
|
|
||||||
|
elif cfg.env.name == "pusht":
|
||||||
|
from lerobot.common.datasets.pusht import PushtDataset
|
||||||
|
|
||||||
|
clsfunc = PushtDataset
|
||||||
|
|
||||||
|
elif cfg.env.name == "aloha":
|
||||||
|
from lerobot.common.datasets.aloha import AlohaDataset
|
||||||
|
|
||||||
|
clsfunc = AlohaDataset
|
||||||
|
else:
|
||||||
|
raise ValueError(cfg.env.name)
|
||||||
|
|
||||||
|
offline_buffer = clsfunc(
|
||||||
|
dataset_id=cfg.dataset_id,
|
||||||
|
sampler=sampler,
|
||||||
|
batch_size=batch_size,
|
||||||
root=DATA_DIR,
|
root=DATA_DIR,
|
||||||
delta_timestamps=delta_timestamps,
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.get("override_dataset_stats"):
|
if cfg.policy.name == "tdmpc":
|
||||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
img_keys = []
|
||||||
for stats_type, listconfig in stats_dict.items():
|
for key in offline_buffer.image_keys:
|
||||||
# example of stats_type: min, max, mean, std
|
img_keys.append(("next", *key))
|
||||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
img_keys += offline_buffer.image_keys
|
||||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
else:
|
||||||
|
img_keys = offline_buffer.image_keys
|
||||||
|
|
||||||
return dataset
|
if normalize:
|
||||||
|
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||||
|
|
||||||
|
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||||
|
# min_max_from_spec
|
||||||
|
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||||
|
|
||||||
|
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||||
|
# now (except for tdmpc: see the following)
|
||||||
|
in_keys = [("observation", "state"), ("action")]
|
||||||
|
|
||||||
|
if cfg.policy.name == "tdmpc":
|
||||||
|
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||||
|
in_keys += img_keys
|
||||||
|
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||||
|
in_keys += [("next", *key) for key in img_keys]
|
||||||
|
in_keys.append(("next", "observation", "state"))
|
||||||
|
|
||||||
|
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||||
|
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||||
|
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||||
|
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||||
|
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
|
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
|
|
||||||
|
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||||
|
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||||
|
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||||
|
|
||||||
|
offline_buffer.set_transform(transforms)
|
||||||
|
|
||||||
|
if not overwrite_sampler:
|
||||||
|
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||||
|
sampler.extend(index)
|
||||||
|
|
||||||
|
return offline_buffer
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import (
|
|
||||||
load_episode_data_index,
|
|
||||||
load_hf_dataset,
|
|
||||||
load_info,
|
|
||||||
load_previous_and_future_frames,
|
|
||||||
load_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDataset(torch.utils.data.Dataset):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
repo_id: str,
|
|
||||||
version: str | None = "v1.1",
|
|
||||||
root: Path | None = None,
|
|
||||||
split: str = "train",
|
|
||||||
transform: callable = None,
|
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.repo_id = repo_id
|
|
||||||
self.version = version
|
|
||||||
self.root = root
|
|
||||||
self.split = split
|
|
||||||
self.transform = transform
|
|
||||||
self.delta_timestamps = delta_timestamps
|
|
||||||
# load data from hub or locally when root is provided
|
|
||||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
|
||||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
|
||||||
self.stats = load_stats(repo_id, version, root)
|
|
||||||
self.info = load_info(repo_id, version, root)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fps(self) -> int:
|
|
||||||
return self.info["fps"]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def image_keys(self) -> list[str]:
|
|
||||||
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_samples(self) -> int:
|
|
||||||
return len(self.hf_dataset)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_episodes(self) -> int:
|
|
||||||
return len(self.hf_dataset.unique("episode_index"))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_samples
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
item = self.hf_dataset[idx]
|
|
||||||
|
|
||||||
if self.delta_timestamps is not None:
|
|
||||||
item = load_previous_and_future_frames(
|
|
||||||
item,
|
|
||||||
self.hf_dataset,
|
|
||||||
self.episode_data_index,
|
|
||||||
self.delta_timestamps,
|
|
||||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.transform is not None:
|
|
||||||
item = self.transform(item)
|
|
||||||
|
|
||||||
return item
|
|
||||||
223
lerobot/common/datasets/pusht.py
Normal file
223
lerobot/common/datasets/pusht.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
import pymunk
|
||||||
|
import torch
|
||||||
|
import torchrl
|
||||||
|
import tqdm
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.replay_buffers.samplers import Sampler
|
||||||
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
|
from lerobot.common.datasets.abstract import AbstractDataset
|
||||||
|
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||||
|
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||||
|
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||||
|
|
||||||
|
# as define in env
|
||||||
|
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||||
|
|
||||||
|
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||||
|
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||||
|
|
||||||
|
|
||||||
|
def get_goal_pose_body(pose):
|
||||||
|
mass = 1
|
||||||
|
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||||
|
body = pymunk.Body(mass, inertia)
|
||||||
|
# preserving the legacy assignment order for compatibility
|
||||||
|
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||||
|
body.position = pose[:2].tolist()
|
||||||
|
body.angle = pose[2]
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def add_segment(space, a, b, radius):
|
||||||
|
shape = pymunk.Segment(space.static_body, a, b, radius)
|
||||||
|
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||||
|
return shape
|
||||||
|
|
||||||
|
|
||||||
|
def add_tee(
|
||||||
|
space,
|
||||||
|
position,
|
||||||
|
angle,
|
||||||
|
scale=30,
|
||||||
|
color="LightSlateGray",
|
||||||
|
mask=None,
|
||||||
|
):
|
||||||
|
if mask is None:
|
||||||
|
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||||
|
mass = 1
|
||||||
|
length = 4
|
||||||
|
vertices1 = [
|
||||||
|
(-length * scale / 2, scale),
|
||||||
|
(length * scale / 2, scale),
|
||||||
|
(length * scale / 2, 0),
|
||||||
|
(-length * scale / 2, 0),
|
||||||
|
]
|
||||||
|
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
vertices2 = [
|
||||||
|
(-scale / 2, scale),
|
||||||
|
(-scale / 2, length * scale),
|
||||||
|
(scale / 2, length * scale),
|
||||||
|
(scale / 2, scale),
|
||||||
|
]
|
||||||
|
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||||
|
shape1 = pymunk.Poly(body, vertices1)
|
||||||
|
shape2 = pymunk.Poly(body, vertices2)
|
||||||
|
shape1.color = pygame.Color(color)
|
||||||
|
shape2.color = pygame.Color(color)
|
||||||
|
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||||
|
body.position = position
|
||||||
|
body.angle = angle
|
||||||
|
body.friction = 1
|
||||||
|
space.add(body, shape1, shape2)
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
class PushtDataset(AbstractDataset):
|
||||||
|
available_datasets = ["pusht"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
version: str | None = "v1.2",
|
||||||
|
batch_size: int | None = None,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
root: Path | None = None,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
prefetch: int = None,
|
||||||
|
sampler: Sampler | None = None,
|
||||||
|
collate_fn: Callable | None = None,
|
||||||
|
writer: Writer | None = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
dataset_id,
|
||||||
|
version,
|
||||||
|
batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
root=root,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
|
sampler=sampler,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
writer=writer,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_and_preproc_obsolete(self):
|
||||||
|
assert self.root is not None
|
||||||
|
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||||
|
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||||
|
if not zarr_path.is_dir():
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||||
|
|
||||||
|
# load
|
||||||
|
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||||
|
zarr_path
|
||||||
|
) # , keys=['img', 'state', 'action'])
|
||||||
|
|
||||||
|
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||||
|
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||||
|
total_frames = dataset_dict["action"].shape[0]
|
||||||
|
# to create test artifact
|
||||||
|
# num_episodes = 1
|
||||||
|
# total_frames = 50
|
||||||
|
assert len(
|
||||||
|
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||||
|
), "Some data type dont have the same number of total frames."
|
||||||
|
|
||||||
|
# TODO: verify that goal pose is expected to be fixed
|
||||||
|
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||||
|
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(dataset_dict["img"])
|
||||||
|
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||||
|
states = torch.from_numpy(dataset_dict["state"])
|
||||||
|
actions = torch.from_numpy(dataset_dict["action"])
|
||||||
|
|
||||||
|
idx0 = 0
|
||||||
|
idxtd = 0
|
||||||
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
|
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||||
|
# to create test artifact
|
||||||
|
# idx1 = 51
|
||||||
|
|
||||||
|
num_frames = idx1 - idx0
|
||||||
|
|
||||||
|
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||||
|
|
||||||
|
image = imgs[idx0:idx1]
|
||||||
|
|
||||||
|
state = states[idx0:idx1]
|
||||||
|
agent_pos = state[:, :2]
|
||||||
|
block_pos = state[:, 2:4]
|
||||||
|
block_angle = state[:, 4]
|
||||||
|
|
||||||
|
reward = torch.zeros(num_frames, 1)
|
||||||
|
success = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||||
|
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||||
|
for i in range(num_frames):
|
||||||
|
space = pymunk.Space()
|
||||||
|
space.gravity = 0, 0
|
||||||
|
space.damping = 0
|
||||||
|
|
||||||
|
# Add walls.
|
||||||
|
walls = [
|
||||||
|
add_segment(space, (5, 506), (5, 5), 2),
|
||||||
|
add_segment(space, (5, 5), (506, 5), 2),
|
||||||
|
add_segment(space, (506, 5), (506, 506), 2),
|
||||||
|
add_segment(space, (5, 506), (506, 506), 2),
|
||||||
|
]
|
||||||
|
space.add(*walls)
|
||||||
|
|
||||||
|
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||||
|
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||||
|
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||||
|
intersection_area = goal_geom.intersection(block_geom).area
|
||||||
|
goal_area = goal_geom.area
|
||||||
|
coverage = intersection_area / goal_area
|
||||||
|
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||||
|
success[i] = coverage > SUCCESS_THRESHOLD
|
||||||
|
|
||||||
|
# last step of demonstration is considered done
|
||||||
|
done[-1] = True
|
||||||
|
|
||||||
|
ep_td = TensorDict(
|
||||||
|
{
|
||||||
|
("observation", "image"): image[:-1],
|
||||||
|
("observation", "state"): agent_pos[:-1],
|
||||||
|
"action": actions[idx0:idx1][:-1],
|
||||||
|
"episode": episode_ids[idx0:idx1][:-1],
|
||||||
|
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||||
|
("next", "observation", "image"): image[1:],
|
||||||
|
("next", "observation", "state"): agent_pos[1:],
|
||||||
|
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||||
|
("next", "reward"): reward[1:],
|
||||||
|
("next", "done"): done[1:],
|
||||||
|
("next", "success"): success[1:],
|
||||||
|
},
|
||||||
|
batch_size=num_frames - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if episode_id == 0:
|
||||||
|
# hack to initialize tensordict data structure to store episodes
|
||||||
|
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||||
|
|
||||||
|
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||||
|
|
||||||
|
idx0 = idx1
|
||||||
|
idxtd = idxtd + len(ep_td)
|
||||||
|
|
||||||
|
return TensorStorage(td_data.lock_())
|
||||||
127
lerobot/common/datasets/simxarm.py
Normal file
127
lerobot/common/datasets/simxarm.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import pickle
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchrl
|
||||||
|
import tqdm
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.replay_buffers.samplers import (
|
||||||
|
Sampler,
|
||||||
|
)
|
||||||
|
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||||
|
from torchrl.data.replay_buffers.writers import Writer
|
||||||
|
|
||||||
|
from lerobot.common.datasets.abstract import AbstractDataset
|
||||||
|
|
||||||
|
|
||||||
|
def download():
|
||||||
|
raise NotImplementedError()
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||||
|
download_path = "data.zip"
|
||||||
|
gdown.download(url, download_path, quiet=False)
|
||||||
|
print("Extracting...")
|
||||||
|
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||||
|
for member in zip_f.namelist():
|
||||||
|
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||||
|
print(member)
|
||||||
|
zip_f.extract(member=member)
|
||||||
|
Path(download_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
class SimxarmDataset(AbstractDataset):
|
||||||
|
available_datasets = [
|
||||||
|
"xarm_lift_medium",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_id: str,
|
||||||
|
version: str | None = "v1.1",
|
||||||
|
batch_size: int | None = None,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
root: Path | None = None,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
prefetch: int = None,
|
||||||
|
sampler: Sampler | None = None,
|
||||||
|
collate_fn: Callable | None = None,
|
||||||
|
writer: Writer | None = None,
|
||||||
|
transform: "torchrl.envs.Transform" = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
dataset_id,
|
||||||
|
version,
|
||||||
|
batch_size,
|
||||||
|
shuffle=shuffle,
|
||||||
|
root=root,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
prefetch=prefetch,
|
||||||
|
sampler=sampler,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
writer=writer,
|
||||||
|
transform=transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_and_preproc_obsolete(self):
|
||||||
|
# assert self.root is not None
|
||||||
|
# TODO(rcadene): finish download
|
||||||
|
# download()
|
||||||
|
|
||||||
|
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||||
|
print(f"Using offline dataset '{dataset_path}'")
|
||||||
|
with open(dataset_path, "rb") as f:
|
||||||
|
dataset_dict = pickle.load(f)
|
||||||
|
|
||||||
|
total_frames = dataset_dict["actions"].shape[0]
|
||||||
|
|
||||||
|
idx0 = 0
|
||||||
|
idx1 = 0
|
||||||
|
episode_id = 0
|
||||||
|
for i in tqdm.tqdm(range(total_frames)):
|
||||||
|
idx1 += 1
|
||||||
|
|
||||||
|
if not dataset_dict["dones"][i]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
num_frames = idx1 - idx0
|
||||||
|
|
||||||
|
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||||
|
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||||
|
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||||
|
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||||
|
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||||
|
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||||
|
|
||||||
|
episode = TensorDict(
|
||||||
|
{
|
||||||
|
("observation", "image"): image,
|
||||||
|
("observation", "state"): state,
|
||||||
|
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||||
|
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
|
"frame_id": torch.arange(0, num_frames, 1),
|
||||||
|
("next", "observation", "image"): next_image,
|
||||||
|
("next", "observation", "state"): next_state,
|
||||||
|
("next", "reward"): next_reward,
|
||||||
|
("next", "done"): next_done,
|
||||||
|
},
|
||||||
|
batch_size=num_frames,
|
||||||
|
)
|
||||||
|
|
||||||
|
if episode_id == 0:
|
||||||
|
# hack to initialize tensordict data structure to store episodes
|
||||||
|
td_data = (
|
||||||
|
episode[0]
|
||||||
|
.expand(total_frames)
|
||||||
|
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
|
||||||
|
)
|
||||||
|
|
||||||
|
td_data[idx0:idx1] = episode
|
||||||
|
|
||||||
|
episode_id += 1
|
||||||
|
idx0 = idx1
|
||||||
|
|
||||||
|
return TensorStorage(td_data.lock_())
|
||||||
@@ -1,359 +1,30 @@
|
|||||||
import json
|
import io
|
||||||
from copy import deepcopy
|
import zipfile
|
||||||
from math import ceil
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import requests
|
||||||
import einops
|
|
||||||
import torch
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from datasets import Image, load_dataset, load_from_disk
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(d, parent_key="", sep="/"):
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
print(f"downloading from {url}")
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
if response.status_code == 200:
|
||||||
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
|
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||||
|
|
||||||
For example:
|
zip_file = io.BytesIO()
|
||||||
```
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
if chunk:
|
||||||
>>> print(flatten_dict(dct))
|
zip_file.write(chunk)
|
||||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
progress_bar.update(len(chunk))
|
||||||
"""
|
|
||||||
items = []
|
|
||||||
for k, v in d.items():
|
|
||||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
|
||||||
if isinstance(v, dict):
|
|
||||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
|
||||||
else:
|
|
||||||
items.append((new_key, v))
|
|
||||||
return dict(items)
|
|
||||||
|
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
def unflatten_dict(d, sep="/"):
|
zip_file.seek(0)
|
||||||
outdict = {}
|
|
||||||
for key, value in d.items():
|
|
||||||
parts = key.split(sep)
|
|
||||||
d = outdict
|
|
||||||
for part in parts[:-1]:
|
|
||||||
if part not in d:
|
|
||||||
d[part] = {}
|
|
||||||
d = d[part]
|
|
||||||
d[parts[-1]] = value
|
|
||||||
return outdict
|
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||||
def hf_transform_to_torch(items_dict):
|
zip_ref.extractall(destination_folder)
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
return True
|
||||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
|
||||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
|
||||||
with channel first (c h w) of float32 type in range [0,1].
|
|
||||||
"""
|
|
||||||
for key in items_dict:
|
|
||||||
first_item = items_dict[key][0]
|
|
||||||
if isinstance(first_item, PILImage.Image):
|
|
||||||
to_tensor = transforms.ToTensor()
|
|
||||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
|
||||||
else:
|
|
||||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
|
||||||
return items_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
|
||||||
if root is not None:
|
|
||||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
|
||||||
else:
|
else:
|
||||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
return False
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
return hf_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
|
||||||
"""episode_data_index contains the range of indices for each episode
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
from_id = episode_data_index["from"][episode_id].item()
|
|
||||||
to_id = episode_data_index["to"][episode_id].item()
|
|
||||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if root is not None:
|
|
||||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
|
||||||
else:
|
|
||||||
path = hf_hub_download(
|
|
||||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
|
||||||
)
|
|
||||||
|
|
||||||
return load_file(path)
|
|
||||||
|
|
||||||
|
|
||||||
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
|
||||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if root is not None:
|
|
||||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
|
||||||
else:
|
|
||||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
|
||||||
|
|
||||||
stats = load_file(path)
|
|
||||||
return unflatten_dict(stats)
|
|
||||||
|
|
||||||
|
|
||||||
def load_info(repo_id, version, root) -> dict:
|
|
||||||
"""info contains useful information regarding the dataset that are not stored elsewhere
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
print("frame per second used to collect the video", info["fps"])
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if root is not None:
|
|
||||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
|
||||||
else:
|
|
||||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
|
||||||
|
|
||||||
with open(path) as f:
|
|
||||||
info = json.load(f)
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
def load_previous_and_future_frames(
|
|
||||||
item: dict[str, torch.Tensor],
|
|
||||||
hf_dataset: datasets.Dataset,
|
|
||||||
episode_data_index: dict[str, torch.Tensor],
|
|
||||||
delta_timestamps: dict[str, list[float]],
|
|
||||||
tol: float,
|
|
||||||
) -> dict[torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
|
||||||
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
|
||||||
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
|
||||||
|
|
||||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
|
||||||
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
|
||||||
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
|
|
||||||
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
|
|
||||||
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
|
|
||||||
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
|
|
||||||
of the episode are the same as the first frame of the episode.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
|
|
||||||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
|
||||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
|
||||||
modality (e.g., "timestamp", "observation.image", "action").
|
|
||||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
|
||||||
They indicate the start index and end index of each episode in the dataset.
|
|
||||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
|
||||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
|
||||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
|
|
||||||
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
|
||||||
smallest expected inter-frame period, but large enough to account for jitter.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
|
|
||||||
each modality (e.g. "observation.image_is_pad").
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
|
|
||||||
issues with timestamps during data collection.
|
|
||||||
"""
|
|
||||||
# get indices of the frames associated to the episode, and their timestamps
|
|
||||||
ep_id = item["episode_index"].item()
|
|
||||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
|
||||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
|
||||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
|
||||||
|
|
||||||
# load timestamps
|
|
||||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
|
||||||
ep_timestamps = torch.stack(ep_timestamps)
|
|
||||||
|
|
||||||
# we make the assumption that the timestamps are sorted
|
|
||||||
ep_first_ts = ep_timestamps[0]
|
|
||||||
ep_last_ts = ep_timestamps[-1]
|
|
||||||
current_ts = item["timestamp"].item()
|
|
||||||
|
|
||||||
for key in delta_timestamps:
|
|
||||||
# get timestamps used as query to retrieve data of previous/future frames
|
|
||||||
delta_ts = delta_timestamps[key]
|
|
||||||
query_ts = current_ts + torch.tensor(delta_ts)
|
|
||||||
|
|
||||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
|
||||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
|
||||||
min_, argmin_ = dist.min(1)
|
|
||||||
|
|
||||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
|
||||||
|
|
||||||
is_pad = min_ > tol
|
|
||||||
|
|
||||||
# check violated query timestamps are all outside the episode range
|
|
||||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
|
||||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
|
||||||
"This might be due to synchronization issues with timestamps during data collection."
|
|
||||||
)
|
|
||||||
|
|
||||||
# get dataset indices corresponding to frames to be loaded
|
|
||||||
data_ids = ep_data_ids[argmin_]
|
|
||||||
|
|
||||||
# load frames modality
|
|
||||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
|
||||||
item[key] = torch.stack(item[key])
|
|
||||||
item[f"{key}_is_pad"] = is_pad
|
|
||||||
|
|
||||||
return item
|
|
||||||
|
|
||||||
|
|
||||||
def get_stats_einops_patterns(hf_dataset):
|
|
||||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
|
||||||
|
|
||||||
Note: We assume the images of `hf_dataset` are in channel first format
|
|
||||||
"""
|
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
hf_dataset,
|
|
||||||
num_workers=0,
|
|
||||||
batch_size=2,
|
|
||||||
shuffle=False,
|
|
||||||
)
|
|
||||||
batch = next(iter(dataloader))
|
|
||||||
|
|
||||||
stats_patterns = {}
|
|
||||||
for key, feats_type in hf_dataset.features.items():
|
|
||||||
# sanity check that tensors are not float64
|
|
||||||
assert batch[key].dtype != torch.float64
|
|
||||||
|
|
||||||
if isinstance(feats_type, Image):
|
|
||||||
# sanity check that images are channel first
|
|
||||||
_, c, h, w = batch[key].shape
|
|
||||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
|
||||||
|
|
||||||
# sanity check that images are float32 in range [0,1]
|
|
||||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
|
||||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
|
||||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
|
||||||
|
|
||||||
stats_patterns[key] = "b c h w -> c 1 1"
|
|
||||||
elif batch[key].ndim == 2:
|
|
||||||
stats_patterns[key] = "b c -> c "
|
|
||||||
elif batch[key].ndim == 1:
|
|
||||||
stats_patterns[key] = "b -> 1"
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
|
||||||
|
|
||||||
return stats_patterns
|
|
||||||
|
|
||||||
|
|
||||||
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
|
||||||
if max_num_samples is None:
|
|
||||||
max_num_samples = len(hf_dataset)
|
|
||||||
|
|
||||||
stats_patterns = get_stats_einops_patterns(hf_dataset)
|
|
||||||
|
|
||||||
# mean and std will be computed incrementally while max and min will track the running value.
|
|
||||||
mean, std, max, min = {}, {}, {}, {}
|
|
||||||
for key in stats_patterns:
|
|
||||||
mean[key] = torch.tensor(0.0).float()
|
|
||||||
std[key] = torch.tensor(0.0).float()
|
|
||||||
max[key] = torch.tensor(-float("inf")).float()
|
|
||||||
min[key] = torch.tensor(float("inf")).float()
|
|
||||||
|
|
||||||
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
|
||||||
generator = torch.Generator()
|
|
||||||
generator.manual_seed(seed)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
hf_dataset,
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=False,
|
|
||||||
generator=generator,
|
|
||||||
)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
|
||||||
# surprises when rerunning the sampler.
|
|
||||||
first_batch = None
|
|
||||||
running_item_count = 0 # for online mean computation
|
|
||||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
|
||||||
for i, batch in enumerate(
|
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
|
||||||
):
|
|
||||||
this_batch_size = len(batch["index"])
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
if first_batch is None:
|
|
||||||
first_batch = deepcopy(batch)
|
|
||||||
for key, pattern in stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation.
|
|
||||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
|
||||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
|
||||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
|
||||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
|
||||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
|
||||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
|
||||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
|
||||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
|
||||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
|
||||||
|
|
||||||
if i == ceil(max_num_samples / batch_size) - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
first_batch_ = None
|
|
||||||
running_item_count = 0 # for online std computation
|
|
||||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
|
||||||
for i, batch in enumerate(
|
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
|
||||||
):
|
|
||||||
this_batch_size = len(batch["index"])
|
|
||||||
running_item_count += this_batch_size
|
|
||||||
# Sanity check to make sure the batches are still in the same order as before.
|
|
||||||
if first_batch_ is None:
|
|
||||||
first_batch_ = deepcopy(batch)
|
|
||||||
for key in stats_patterns:
|
|
||||||
assert torch.equal(first_batch_[key], first_batch[key])
|
|
||||||
for key, pattern in stats_patterns.items():
|
|
||||||
batch[key] = batch[key].float()
|
|
||||||
# Numerically stable update step for mean computation (where the mean is over squared
|
|
||||||
# residuals).See notes in the mean computation loop above.
|
|
||||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
|
||||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
|
||||||
|
|
||||||
if i == ceil(max_num_samples / batch_size) - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
for key in stats_patterns:
|
|
||||||
std[key] = torch.sqrt(std[key])
|
|
||||||
|
|
||||||
stats = {}
|
|
||||||
for key in stats_patterns:
|
|
||||||
stats[key] = {
|
|
||||||
"mean": mean[key],
|
|
||||||
"std": std[key],
|
|
||||||
"max": max[key],
|
|
||||||
"min": min[key],
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
def cycle(iterable):
|
|
||||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
|
||||||
|
|
||||||
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
|
||||||
"""
|
|
||||||
iterator = iter(iterable)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
yield next(iterator)
|
|
||||||
except StopIteration:
|
|
||||||
iterator = iter(iterable)
|
|
||||||
|
|||||||
92
lerobot/common/envs/abstract.py
Normal file
92
lerobot/common/envs/abstract.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.envs import EnvBase
|
||||||
|
|
||||||
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractEnv(EnvBase):
|
||||||
|
"""
|
||||||
|
Note:
|
||||||
|
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||||
|
1. set the required class attributes:
|
||||||
|
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||||
|
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||||
|
- for classes inheriting from `AbstractPolicy`: `name`
|
||||||
|
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||||
|
3. update variables in `tests/test_available.py` by importing your new class
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str | None = None # same name should be used to instantiate the environment in factory.py
|
||||||
|
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
frame_skip: int = 1,
|
||||||
|
from_pixels: bool = False,
|
||||||
|
pixels_only: bool = False,
|
||||||
|
image_size=None,
|
||||||
|
seed=1337,
|
||||||
|
device="cpu",
|
||||||
|
num_prev_obs=1,
|
||||||
|
num_prev_action=0,
|
||||||
|
):
|
||||||
|
super().__init__(device=device, batch_size=[])
|
||||||
|
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
|
||||||
|
assert (
|
||||||
|
self.available_tasks is not None
|
||||||
|
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
|
||||||
|
assert (
|
||||||
|
task in self.available_tasks
|
||||||
|
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
|
||||||
|
|
||||||
|
self.task = task
|
||||||
|
self.frame_skip = frame_skip
|
||||||
|
self.from_pixels = from_pixels
|
||||||
|
self.pixels_only = pixels_only
|
||||||
|
self.image_size = image_size
|
||||||
|
self.num_prev_obs = num_prev_obs
|
||||||
|
self.num_prev_action = num_prev_action
|
||||||
|
|
||||||
|
if pixels_only:
|
||||||
|
assert from_pixels
|
||||||
|
if from_pixels:
|
||||||
|
assert image_size
|
||||||
|
|
||||||
|
self._make_env()
|
||||||
|
self._make_spec()
|
||||||
|
|
||||||
|
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
|
||||||
|
# you store the return value in self._next_seed (it will be a new randomly generated seed).
|
||||||
|
self._next_seed = seed
|
||||||
|
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
|
||||||
|
# self._reset is called, we use seed.
|
||||||
|
self.set_seed(seed)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||||
|
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||||
|
if self.num_prev_action > 0:
|
||||||
|
raise NotImplementedError()
|
||||||
|
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def _step(self, tensordict: TensorDict):
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def _make_env(self):
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def _make_spec(self):
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def _set_seed(self, seed: Optional[int]):
|
||||||
|
set_global_seed(seed)
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
<mujoco>
|
||||||
|
<include file="scene.xml"/>
|
||||||
|
<include file="vx300s_dependencies.xml"/>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||||
|
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<include file="vx300s_left.xml" />
|
||||||
|
<include file="vx300s_right.xml" />
|
||||||
|
|
||||||
|
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||||
|
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||||
|
</body>
|
||||||
|
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||||
|
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="peg" pos="0.2 0.5 0.05">
|
||||||
|
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="socket" pos="-0.2 0.5 0.05">
|
||||||
|
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
<keyframe>
|
||||||
|
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||||
|
</keyframe>
|
||||||
|
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
<mujoco>
|
||||||
|
<include file="scene.xml"/>
|
||||||
|
<include file="vx300s_dependencies.xml"/>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||||
|
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<include file="vx300s_left.xml" />
|
||||||
|
<include file="vx300s_right.xml" />
|
||||||
|
|
||||||
|
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||||
|
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||||
|
</body>
|
||||||
|
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||||
|
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||||
|
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="box" pos="0.2 0.5 0.05">
|
||||||
|
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
<keyframe>
|
||||||
|
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||||
|
</keyframe>
|
||||||
|
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
<mujoco>
|
||||||
|
<include file="scene.xml"/>
|
||||||
|
<include file="vx300s_dependencies.xml"/>
|
||||||
|
<worldbody>
|
||||||
|
<include file="vx300s_left.xml" />
|
||||||
|
<include file="vx300s_right.xml" />
|
||||||
|
|
||||||
|
<body name="peg" pos="0.2 0.5 0.05">
|
||||||
|
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="socket" pos="-0.2 0.5 0.05">
|
||||||
|
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
<keyframe>
|
||||||
|
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||||
|
</keyframe>
|
||||||
|
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
<mujoco>
|
||||||
|
<include file="scene.xml"/>
|
||||||
|
<include file="vx300s_dependencies.xml"/>
|
||||||
|
<worldbody>
|
||||||
|
<include file="vx300s_left.xml" />
|
||||||
|
<include file="vx300s_right.xml" />
|
||||||
|
|
||||||
|
<body name="box" pos="0.2 0.5 0.05">
|
||||||
|
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||||
|
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||||
|
|
||||||
|
</actuator>
|
||||||
|
|
||||||
|
<keyframe>
|
||||||
|
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||||
|
</keyframe>
|
||||||
|
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
<mujocoinclude>
|
||||||
|
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||||
|
|
||||||
|
<asset>
|
||||||
|
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
<visual>
|
||||||
|
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||||
|
<quality shadowsize="4096" offsamples="4"/>
|
||||||
|
<headlight ambient="0.4 0.4 0.4"/>
|
||||||
|
</visual>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||||
|
dir='1 1 -1'/>
|
||||||
|
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||||
|
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||||
|
dir='0 -1 -1'/>
|
||||||
|
|
||||||
|
<body name="table" pos="0 .6 0">
|
||||||
|
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||||
|
</body>
|
||||||
|
<body name="midair" pos="0 .6 0.2">
|
||||||
|
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||||
|
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||||
|
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||||
|
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||||
|
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||||
|
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
</mujocoinclude>
|
||||||
3
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
3
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:76a1571d1aa36520f2bd81c268991b99816c2a7819464d718e0fd9976fe30dce
|
||||||
|
size 684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:df73ae5b9058e5d50a6409ac2ab687dade75053a86591bb5e23ab051dbf2d659
|
||||||
|
size 83384
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:56fb3cc1236d4193106038adf8e457c7252ae9e86c7cee6dabf0578c53666358
|
||||||
|
size 83384
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a4baacd9a64df1be60ea5e98f50f3c660e1b7a1fe9684aace6004c5058c09483
|
||||||
|
size 42884
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a18a1601074d29ed1d546ead70cd18fbb063f1db7b5b96b9f0365be714f3136a
|
||||||
|
size 3884
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:d100cafe656671ca8fde98fb6a4cf2d1b746995c51c61c25ad9ea2715635d146
|
||||||
|
size 99984
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:139745a74055cb0b23430bb5bc032bf68cf7bea5e4975c8f4c04107ae005f7f0
|
||||||
|
size 63884
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:900f236320dd3d500870c5fde763b2d47502d51e043a5c377875e70237108729
|
||||||
|
size 102984
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:4104fc54bbfb8a9b533029f1e7e3ade3d54d638372b3195daa0c98f57e0295b5
|
||||||
|
size 49584
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:66814e27fa728056416e25e02e89eb7d34c51d51c51e7c3df873829037ddc6b8
|
||||||
|
size 99884
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:90eb145c85627968c3776ae6de23ccff7e112c9dd713c46bc9acdfdaa859a048
|
||||||
|
size 70784
|
||||||
3
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
3
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:786c1077bfd226f14219581b11d5f19464ca95b17132e0bb7532503568f5af90
|
||||||
|
size 450084
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:d1275a93fe2157c83dbc095617fb7e672888bdd48ec070a35ef4ab9ebd9755b0
|
||||||
|
size 31684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a4de62c9a2ed2c78433010e4c05530a1254b1774a7651967f406120c9bf8973e
|
||||||
|
size 379484
|
||||||
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<mujocoinclude>
|
||||||
|
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||||
|
<asset>
|
||||||
|
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||||
|
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
</mujocoinclude>
|
||||||
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
|
||||||
|
<mujocoinclude>
|
||||||
|
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||||
|
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||||
|
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||||
|
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||||
|
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||||
|
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||||
|
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||||
|
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||||
|
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||||
|
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||||
|
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||||
|
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||||
|
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||||
|
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||||
|
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||||
|
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||||
|
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||||
|
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||||
|
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||||
|
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||||
|
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||||
|
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||||
|
</body>
|
||||||
|
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||||
|
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||||
|
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||||
|
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||||
|
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||||
|
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||||
|
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||||
|
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||||
|
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||||
|
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||||
|
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||||
|
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||||
|
</body>
|
||||||
|
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||||
|
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||||
|
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||||
|
</body>
|
||||||
|
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||||
|
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||||
|
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</mujocoinclude>
|
||||||
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
|
||||||
|
<mujocoinclude>
|
||||||
|
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||||
|
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||||
|
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||||
|
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||||
|
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||||
|
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||||
|
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||||
|
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||||
|
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||||
|
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||||
|
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||||
|
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||||
|
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||||
|
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||||
|
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||||
|
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||||
|
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||||
|
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||||
|
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||||
|
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||||
|
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||||
|
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||||
|
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||||
|
</body>
|
||||||
|
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||||
|
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||||
|
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||||
|
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||||
|
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||||
|
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||||
|
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||||
|
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||||
|
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||||
|
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||||
|
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||||
|
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||||
|
</body>
|
||||||
|
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||||
|
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||||
|
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||||
|
</body>
|
||||||
|
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||||
|
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||||
|
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||||
|
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</mujocoinclude>
|
||||||
163
lerobot/common/envs/aloha/constants.py
Normal file
163
lerobot/common/envs/aloha/constants.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
### Simulation envs fixed constants
|
||||||
|
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
|
||||||
|
FPS = 50
|
||||||
|
|
||||||
|
|
||||||
|
JOINTS = [
|
||||||
|
# absolute joint position
|
||||||
|
"left_arm_waist",
|
||||||
|
"left_arm_shoulder",
|
||||||
|
"left_arm_elbow",
|
||||||
|
"left_arm_forearm_roll",
|
||||||
|
"left_arm_wrist_angle",
|
||||||
|
"left_arm_wrist_rotate",
|
||||||
|
# normalized gripper position 0: close, 1: open
|
||||||
|
"left_arm_gripper",
|
||||||
|
# absolute joint position
|
||||||
|
"right_arm_waist",
|
||||||
|
"right_arm_shoulder",
|
||||||
|
"right_arm_elbow",
|
||||||
|
"right_arm_forearm_roll",
|
||||||
|
"right_arm_wrist_angle",
|
||||||
|
"right_arm_wrist_rotate",
|
||||||
|
# normalized gripper position 0: close, 1: open
|
||||||
|
"right_arm_gripper",
|
||||||
|
]
|
||||||
|
|
||||||
|
ACTIONS = [
|
||||||
|
# position and quaternion for end effector
|
||||||
|
"left_arm_waist",
|
||||||
|
"left_arm_shoulder",
|
||||||
|
"left_arm_elbow",
|
||||||
|
"left_arm_forearm_roll",
|
||||||
|
"left_arm_wrist_angle",
|
||||||
|
"left_arm_wrist_rotate",
|
||||||
|
# normalized gripper position (0: close, 1: open)
|
||||||
|
"left_arm_gripper",
|
||||||
|
"right_arm_waist",
|
||||||
|
"right_arm_shoulder",
|
||||||
|
"right_arm_elbow",
|
||||||
|
"right_arm_forearm_roll",
|
||||||
|
"right_arm_wrist_angle",
|
||||||
|
"right_arm_wrist_rotate",
|
||||||
|
# normalized gripper position (0: close, 1: open)
|
||||||
|
"right_arm_gripper",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
START_ARM_POSE = [
|
||||||
|
0,
|
||||||
|
-0.96,
|
||||||
|
1.16,
|
||||||
|
0,
|
||||||
|
-0.3,
|
||||||
|
0,
|
||||||
|
0.02239,
|
||||||
|
-0.02239,
|
||||||
|
0,
|
||||||
|
-0.96,
|
||||||
|
1.16,
|
||||||
|
0,
|
||||||
|
-0.3,
|
||||||
|
0,
|
||||||
|
0.02239,
|
||||||
|
-0.02239,
|
||||||
|
]
|
||||||
|
|
||||||
|
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
|
||||||
|
|
||||||
|
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||||
|
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||||
|
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||||
|
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||||
|
|
||||||
|
# Gripper joint limits (qpos[6])
|
||||||
|
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||||
|
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||||
|
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||||
|
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||||
|
|
||||||
|
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||||
|
|
||||||
|
############################ Helper functions ############################
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_master_gripper_position(x):
|
||||||
|
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||||
|
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_puppet_gripper_position(x):
|
||||||
|
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||||
|
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize_master_gripper_position(x):
|
||||||
|
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize_puppet_gripper_position(x):
|
||||||
|
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||||
|
|
||||||
|
|
||||||
|
def convert_position_from_master_to_puppet(x):
|
||||||
|
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
|
||||||
|
|
||||||
|
|
||||||
|
def normalizer_master_gripper_joint(x):
|
||||||
|
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_puppet_gripper_joint(x):
|
||||||
|
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize_master_gripper_joint(x):
|
||||||
|
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||||
|
|
||||||
|
|
||||||
|
def unnormalize_puppet_gripper_joint(x):
|
||||||
|
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||||
|
|
||||||
|
|
||||||
|
def convert_join_from_master_to_puppet(x):
|
||||||
|
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_master_gripper_velocity(x):
|
||||||
|
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_puppet_gripper_velocity(x):
|
||||||
|
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_master_from_position_to_joint(x):
|
||||||
|
return (
|
||||||
|
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||||
|
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_master_from_joint_to_position(x):
|
||||||
|
return unnormalize_master_gripper_position(
|
||||||
|
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_puppet_from_position_to_join(x):
|
||||||
|
return (
|
||||||
|
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||||
|
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_puppet_from_joint_to_position(x):
|
||||||
|
return unnormalize_puppet_gripper_position(
|
||||||
|
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||||
|
)
|
||||||
298
lerobot/common/envs/aloha/env.py
Normal file
298
lerobot/common/envs/aloha/env.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.tensor_specs import (
|
||||||
|
BoundedTensorSpec,
|
||||||
|
CompositeSpec,
|
||||||
|
DiscreteTensorSpec,
|
||||||
|
UnboundedContinuousTensorSpec,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
|
from lerobot.common.envs.aloha.constants import (
|
||||||
|
ACTIONS,
|
||||||
|
ASSETS_DIR,
|
||||||
|
DT,
|
||||||
|
JOINTS,
|
||||||
|
)
|
||||||
|
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
|
||||||
|
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||||
|
InsertionEndEffectorTask,
|
||||||
|
TransferCubeEndEffectorTask,
|
||||||
|
)
|
||||||
|
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||||
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
|
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class AlohaEnv(AbstractEnv):
|
||||||
|
name = "aloha"
|
||||||
|
available_tasks = ["sim_insertion", "sim_transfer_cube"]
|
||||||
|
_reset_warning_issued = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
frame_skip: int = 1,
|
||||||
|
from_pixels: bool = False,
|
||||||
|
pixels_only: bool = False,
|
||||||
|
image_size=None,
|
||||||
|
seed=1337,
|
||||||
|
device="cpu",
|
||||||
|
num_prev_obs=1,
|
||||||
|
num_prev_action=0,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
task=task,
|
||||||
|
frame_skip=frame_skip,
|
||||||
|
from_pixels=from_pixels,
|
||||||
|
pixels_only=pixels_only,
|
||||||
|
image_size=image_size,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
num_prev_obs=num_prev_obs,
|
||||||
|
num_prev_action=num_prev_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_env(self):
|
||||||
|
if not _has_gym:
|
||||||
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
|
if not self.from_pixels:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self._env = self._make_env_task(self.task)
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=640, height=480):
|
||||||
|
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||||
|
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _make_env_task(self, task_name):
|
||||||
|
# time limit is controlled by StepCounter in env factory
|
||||||
|
time_limit = float("inf")
|
||||||
|
|
||||||
|
if "sim_transfer_cube" in task_name:
|
||||||
|
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
|
||||||
|
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||||
|
task = TransferCubeTask(random=False)
|
||||||
|
elif "sim_insertion" in task_name:
|
||||||
|
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
|
||||||
|
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||||
|
task = InsertionTask(random=False)
|
||||||
|
elif "sim_end_effector_transfer_cube" in task_name:
|
||||||
|
raise NotImplementedError()
|
||||||
|
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
|
||||||
|
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||||
|
task = TransferCubeEndEffectorTask(random=False)
|
||||||
|
elif "sim_end_effector_insertion" in task_name:
|
||||||
|
raise NotImplementedError()
|
||||||
|
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
|
||||||
|
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||||
|
task = InsertionEndEffectorTask(random=False)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(task_name)
|
||||||
|
|
||||||
|
env = control.Environment(
|
||||||
|
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
|
||||||
|
)
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _format_raw_obs(self, raw_obs):
|
||||||
|
if self.from_pixels:
|
||||||
|
image = torch.from_numpy(raw_obs["images"]["top"].copy())
|
||||||
|
image = einops.rearrange(image, "h w c -> c h w")
|
||||||
|
assert image.dtype == torch.uint8
|
||||||
|
obs = {"image": {"top": image}}
|
||||||
|
|
||||||
|
if not self.pixels_only:
|
||||||
|
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
|
||||||
|
else:
|
||||||
|
# TODO(rcadene):
|
||||||
|
raise NotImplementedError()
|
||||||
|
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
if tensordict is not None and not AlohaEnv._reset_warning_issued:
|
||||||
|
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||||
|
AlohaEnv._reset_warning_issued = True
|
||||||
|
|
||||||
|
# Seed the environment and update the seed to be used for the next reset.
|
||||||
|
self._next_seed = self.set_seed(self._next_seed)
|
||||||
|
|
||||||
|
# TODO(rcadene): do not use global variable for this
|
||||||
|
if "sim_transfer_cube" in self.task:
|
||||||
|
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||||
|
elif "sim_insertion" in self.task:
|
||||||
|
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||||
|
|
||||||
|
raw_obs = self._env.reset()
|
||||||
|
|
||||||
|
obs = self._format_raw_obs(raw_obs.observation)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue = deque(
|
||||||
|
[obs["image"]["top"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue = deque(
|
||||||
|
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
"done": torch.tensor([False], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _step(self, tensordict: TensorDict):
|
||||||
|
td = tensordict
|
||||||
|
action = td["action"].numpy()
|
||||||
|
assert action.ndim == 1
|
||||||
|
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||||
|
|
||||||
|
_, reward, _, raw_obs = self._env.step(action)
|
||||||
|
|
||||||
|
# TODO(rcadene): add an enum
|
||||||
|
success = done = reward == 4
|
||||||
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue.append(obs["image"]["top"])
|
||||||
|
stacked_obs["image"] = {"top": torch.stack(list(self._prev_obs_image_queue))}
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue.append(obs["state"])
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||||
|
# success and done are true when coverage > self.success_threshold in env
|
||||||
|
"done": torch.tensor([done], dtype=torch.bool),
|
||||||
|
"success": torch.tensor([success], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _make_spec(self):
|
||||||
|
obs = {}
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
if self.from_pixels:
|
||||||
|
if isinstance(self.image_size, int):
|
||||||
|
image_shape = (3, self.image_size, self.image_size)
|
||||||
|
elif OmegaConf.is_list(self.image_size) or isinstance(self.image_size, list):
|
||||||
|
assert len(self.image_size) == 3 # c h w
|
||||||
|
assert self.image_size[0] == 3 # c is RGB
|
||||||
|
image_shape = tuple(self.image_size)
|
||||||
|
else:
|
||||||
|
raise ValueError(self.image_size)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||||
|
|
||||||
|
obs["image"] = {
|
||||||
|
"top": BoundedTensorSpec(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=image_shape,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if not self.pixels_only:
|
||||||
|
state_shape = (len(JOINTS),)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
|
# TODO: add low and high bounds
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||||
|
state_shape = (len(JOINTS),)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
|
# TODO: add low and high bounds
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.observation_spec = CompositeSpec({"observation": obs})
|
||||||
|
|
||||||
|
# TODO(rcadene): valid when controling end effector?
|
||||||
|
# action_space = self._env.action_spec()
|
||||||
|
# self.action_spec = BoundedTensorSpec(
|
||||||
|
# low=action_space.minimum,
|
||||||
|
# high=action_space.maximum,
|
||||||
|
# shape=action_space.shape,
|
||||||
|
# dtype=torch.float32,
|
||||||
|
# device=self.device,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# TODO(rcaene): add bounds (where are they????)
|
||||||
|
self.action_spec = BoundedTensorSpec(
|
||||||
|
shape=(len(ACTIONS)),
|
||||||
|
low=-1,
|
||||||
|
high=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.done_spec = CompositeSpec(
|
||||||
|
{
|
||||||
|
"done": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"success": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_seed(self, seed: Optional[int]):
|
||||||
|
set_global_seed(seed)
|
||||||
|
# TODO(rcadene): seed the env
|
||||||
|
# self._env.seed(seed)
|
||||||
|
logging.warning("Aloha env is not seeded")
|
||||||
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from dm_control.suite import base
|
||||||
|
|
||||||
|
from lerobot.common.envs.aloha.constants import (
|
||||||
|
START_ARM_POSE,
|
||||||
|
normalize_puppet_gripper_position,
|
||||||
|
normalize_puppet_gripper_velocity,
|
||||||
|
unnormalize_puppet_gripper_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
BOX_POSE = [None] # to be changed from outside
|
||||||
|
|
||||||
|
"""
|
||||||
|
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||||
|
Action space: [left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||||
|
|
||||||
|
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||||
|
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BimanualViperXTask(base.Task):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
|
||||||
|
def before_step(self, action, physics):
|
||||||
|
left_arm_action = action[:6]
|
||||||
|
right_arm_action = action[7 : 7 + 6]
|
||||||
|
normalized_left_gripper_action = action[6]
|
||||||
|
normalized_right_gripper_action = action[7 + 6]
|
||||||
|
|
||||||
|
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
|
||||||
|
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
|
||||||
|
|
||||||
|
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||||
|
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||||
|
|
||||||
|
env_action = np.concatenate(
|
||||||
|
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
|
||||||
|
)
|
||||||
|
super().before_step(env_action, physics)
|
||||||
|
return
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qpos(physics):
|
||||||
|
qpos_raw = physics.data.qpos.copy()
|
||||||
|
left_qpos_raw = qpos_raw[:8]
|
||||||
|
right_qpos_raw = qpos_raw[8:16]
|
||||||
|
left_arm_qpos = left_qpos_raw[:6]
|
||||||
|
right_arm_qpos = right_qpos_raw[:6]
|
||||||
|
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||||
|
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qvel(physics):
|
||||||
|
qvel_raw = physics.data.qvel.copy()
|
||||||
|
left_qvel_raw = qvel_raw[:8]
|
||||||
|
right_qvel_raw = qvel_raw[8:16]
|
||||||
|
left_arm_qvel = left_qvel_raw[:6]
|
||||||
|
right_arm_qvel = right_qvel_raw[:6]
|
||||||
|
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||||
|
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs["qpos"] = self.get_qpos(physics)
|
||||||
|
obs["qvel"] = self.get_qvel(physics)
|
||||||
|
obs["env_state"] = self.get_env_state(physics)
|
||||||
|
obs["images"] = {}
|
||||||
|
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||||
|
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||||
|
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCubeTask(BimanualViperXTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||||
|
# reset qpos, control and box position
|
||||||
|
with physics.reset_context():
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||||
|
assert BOX_POSE[0] is not None
|
||||||
|
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||||
|
# print(f"{BOX_POSE=}")
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_table = ("red_box", "table") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_right_gripper:
|
||||||
|
reward = 1
|
||||||
|
if touch_right_gripper and not touch_table: # lifted
|
||||||
|
reward = 2
|
||||||
|
if touch_left_gripper: # attempted transfer
|
||||||
|
reward = 3
|
||||||
|
if touch_left_gripper and not touch_table: # successful transfer
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
class InsertionTask(BimanualViperXTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||||
|
# reset qpos, control and box position
|
||||||
|
with physics.reset_context():
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||||
|
assert BOX_POSE[0] is not None
|
||||||
|
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
|
||||||
|
# print(f"{BOX_POSE=}")
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether peg touches the pin
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_left_gripper = (
|
||||||
|
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
)
|
||||||
|
|
||||||
|
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||||
|
socket_touch_table = (
|
||||||
|
("socket-1", "table") in all_contact_pairs
|
||||||
|
or ("socket-2", "table") in all_contact_pairs
|
||||||
|
or ("socket-3", "table") in all_contact_pairs
|
||||||
|
or ("socket-4", "table") in all_contact_pairs
|
||||||
|
)
|
||||||
|
peg_touch_socket = (
|
||||||
|
("red_peg", "socket-1") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-2") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-3") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-4") in all_contact_pairs
|
||||||
|
)
|
||||||
|
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_left_gripper and touch_right_gripper: # touch both
|
||||||
|
reward = 1
|
||||||
|
if (
|
||||||
|
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||||
|
): # grasp both
|
||||||
|
reward = 2
|
||||||
|
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||||
|
reward = 3
|
||||||
|
if pin_touched: # successful insertion
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from dm_control.suite import base
|
||||||
|
|
||||||
|
from lerobot.common.envs.aloha.constants import (
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
START_ARM_POSE,
|
||||||
|
normalize_puppet_gripper_position,
|
||||||
|
normalize_puppet_gripper_velocity,
|
||||||
|
unnormalize_puppet_gripper_position,
|
||||||
|
)
|
||||||
|
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||||
|
|
||||||
|
"""
|
||||||
|
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||||
|
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||||
|
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_pose (7), # position and quaternion for end effector
|
||||||
|
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||||
|
|
||||||
|
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||||
|
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BimanualViperXEndEffectorTask(base.Task):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
|
||||||
|
def before_step(self, action, physics):
|
||||||
|
a_len = len(action) // 2
|
||||||
|
action_left = action[:a_len]
|
||||||
|
action_right = action[a_len:]
|
||||||
|
|
||||||
|
# set mocap position and quat
|
||||||
|
# left
|
||||||
|
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||||
|
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||||
|
# right
|
||||||
|
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||||
|
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||||
|
|
||||||
|
# set gripper
|
||||||
|
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
|
||||||
|
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
|
||||||
|
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||||
|
|
||||||
|
def initialize_robots(self, physics):
|
||||||
|
# reset joint position
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
|
||||||
|
# reset mocap to align with end effector
|
||||||
|
# to obtain these numbers:
|
||||||
|
# (1) make an ee_sim env and reset to the same start_pose
|
||||||
|
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||||
|
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||||
|
# repeat the same for right side
|
||||||
|
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||||
|
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||||
|
# right
|
||||||
|
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||||
|
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||||
|
|
||||||
|
# reset gripper control
|
||||||
|
close_gripper_control = np.array(
|
||||||
|
[
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qpos(physics):
|
||||||
|
qpos_raw = physics.data.qpos.copy()
|
||||||
|
left_qpos_raw = qpos_raw[:8]
|
||||||
|
right_qpos_raw = qpos_raw[8:16]
|
||||||
|
left_arm_qpos = left_qpos_raw[:6]
|
||||||
|
right_arm_qpos = right_qpos_raw[:6]
|
||||||
|
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||||
|
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qvel(physics):
|
||||||
|
qvel_raw = physics.data.qvel.copy()
|
||||||
|
left_qvel_raw = qvel_raw[:8]
|
||||||
|
right_qvel_raw = qvel_raw[8:16]
|
||||||
|
left_arm_qvel = left_qvel_raw[:6]
|
||||||
|
right_arm_qvel = right_qvel_raw[:6]
|
||||||
|
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||||
|
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
# note: it is important to do .copy()
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs["qpos"] = self.get_qpos(physics)
|
||||||
|
obs["qvel"] = self.get_qvel(physics)
|
||||||
|
obs["env_state"] = self.get_env_state(physics)
|
||||||
|
obs["images"] = {}
|
||||||
|
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||||
|
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||||
|
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||||
|
# used in scripted policy to obtain starting pose
|
||||||
|
obs["mocap_pose_left"] = np.concatenate(
|
||||||
|
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
|
||||||
|
).copy()
|
||||||
|
obs["mocap_pose_right"] = np.concatenate(
|
||||||
|
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
|
||||||
|
).copy()
|
||||||
|
|
||||||
|
# used when replaying joint trajectory
|
||||||
|
obs["gripper_ctrl"] = physics.data.ctrl.copy()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
self.initialize_robots(physics)
|
||||||
|
# randomize box position
|
||||||
|
cube_pose = sample_box_pose()
|
||||||
|
box_start_idx = physics.model.name2id("red_box_joint", "joint")
|
||||||
|
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_table = ("red_box", "table") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_right_gripper:
|
||||||
|
reward = 1
|
||||||
|
if touch_right_gripper and not touch_table: # lifted
|
||||||
|
reward = 2
|
||||||
|
if touch_left_gripper: # attempted transfer
|
||||||
|
reward = 3
|
||||||
|
if touch_left_gripper and not touch_table: # successful transfer
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
self.initialize_robots(physics)
|
||||||
|
# randomize peg and socket position
|
||||||
|
peg_pose, socket_pose = sample_insertion_pose()
|
||||||
|
|
||||||
|
def id2index(j_id):
|
||||||
|
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||||
|
|
||||||
|
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
|
||||||
|
peg_start_idx = id2index(peg_start_id)
|
||||||
|
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
|
||||||
|
socket_start_idx = id2index(socket_start_id)
|
||||||
|
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether peg touches the pin
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_left_gripper = (
|
||||||
|
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
)
|
||||||
|
|
||||||
|
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||||
|
socket_touch_table = (
|
||||||
|
("socket-1", "table") in all_contact_pairs
|
||||||
|
or ("socket-2", "table") in all_contact_pairs
|
||||||
|
or ("socket-3", "table") in all_contact_pairs
|
||||||
|
or ("socket-4", "table") in all_contact_pairs
|
||||||
|
)
|
||||||
|
peg_touch_socket = (
|
||||||
|
("red_peg", "socket-1") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-2") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-3") in all_contact_pairs
|
||||||
|
or ("red_peg", "socket-4") in all_contact_pairs
|
||||||
|
)
|
||||||
|
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_left_gripper and touch_right_gripper: # touch both
|
||||||
|
reward = 1
|
||||||
|
if (
|
||||||
|
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||||
|
): # grasp both
|
||||||
|
reward = 2
|
||||||
|
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||||
|
reward = 3
|
||||||
|
if pin_touched: # successful insertion
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
39
lerobot/common/envs/aloha/utils.py
Normal file
39
lerobot/common/envs/aloha/utils.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def sample_box_pose():
|
||||||
|
x_range = [0.0, 0.2]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
cube_quat = np.array([1, 0, 0, 0])
|
||||||
|
return np.concatenate([cube_position, cube_quat])
|
||||||
|
|
||||||
|
|
||||||
|
def sample_insertion_pose():
|
||||||
|
# Peg
|
||||||
|
x_range = [0.1, 0.2]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
peg_quat = np.array([1, 0, 0, 0])
|
||||||
|
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||||
|
|
||||||
|
# Socket
|
||||||
|
x_range = [-0.2, -0.1]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
socket_quat = np.array([1, 0, 0, 0])
|
||||||
|
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||||
|
|
||||||
|
return peg_pose, socket_pose
|
||||||
@@ -1,43 +1,64 @@
|
|||||||
import importlib
|
from torchrl.envs import SerialEnv
|
||||||
|
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||||
import gymnasium as gym
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
def make_env(cfg, transform=None):
|
||||||
"""
|
"""
|
||||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
environments. The env therefore returns batches.`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"obs_type": "pixels_agent_pos",
|
"frame_skip": cfg.env.action_repeat,
|
||||||
"render_mode": "rgb_array",
|
"from_pixels": cfg.env.from_pixels,
|
||||||
"max_episode_steps": cfg.env.episode_length,
|
"pixels_only": cfg.env.pixels_only,
|
||||||
"visualization_width": 384,
|
"image_size": cfg.env.image_size,
|
||||||
"visualization_height": 384,
|
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
package_name = f"gym_{cfg.env.name}"
|
if cfg.env.name == "simxarm":
|
||||||
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||||
|
|
||||||
try:
|
kwargs["task"] = cfg.env.task
|
||||||
importlib.import_module(package_name)
|
clsfunc = SimxarmEnv
|
||||||
except ModuleNotFoundError as e:
|
elif cfg.env.name == "pusht":
|
||||||
print(
|
from lerobot.common.envs.pusht.env import PushtEnv
|
||||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
|
||||||
)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||||
|
|
||||||
if num_parallel_envs == 0:
|
clsfunc = PushtEnv
|
||||||
# non-batched version of the env that returns an observation of shape (c)
|
elif cfg.env.name == "aloha":
|
||||||
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||||
|
|
||||||
|
kwargs["task"] = cfg.env.task
|
||||||
|
clsfunc = AlohaEnv
|
||||||
else:
|
else:
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
raise ValueError(cfg.env.name)
|
||||||
env = gym.vector.SyncVectorEnv(
|
|
||||||
[
|
|
||||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
|
||||||
for _ in range(num_parallel_envs)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return env
|
def _make_env(seed):
|
||||||
|
nonlocal kwargs
|
||||||
|
kwargs["seed"] = seed
|
||||||
|
env = clsfunc(**kwargs)
|
||||||
|
|
||||||
|
# limit rollout to max_steps
|
||||||
|
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||||
|
|
||||||
|
if transform is not None:
|
||||||
|
# useful to add normalization
|
||||||
|
if isinstance(transform, Compose):
|
||||||
|
for tf in transform:
|
||||||
|
env.append_transform(tf.clone())
|
||||||
|
elif isinstance(transform, Transform):
|
||||||
|
env.append_transform(transform.clone())
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
return SerialEnv(
|
||||||
|
cfg.rollout_batch_size,
|
||||||
|
create_env_fn=_make_env,
|
||||||
|
create_env_kwargs=[
|
||||||
|
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
245
lerobot/common/envs/pusht/env.py
Normal file
245
lerobot/common/envs/pusht/env.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.tensor_specs import (
|
||||||
|
BoundedTensorSpec,
|
||||||
|
CompositeSpec,
|
||||||
|
DiscreteTensorSpec,
|
||||||
|
UnboundedContinuousTensorSpec,
|
||||||
|
)
|
||||||
|
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||||
|
|
||||||
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
|
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class PushtEnv(AbstractEnv):
|
||||||
|
name = "pusht"
|
||||||
|
available_tasks = ["pusht"]
|
||||||
|
_reset_warning_issued = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task="pusht",
|
||||||
|
frame_skip: int = 1,
|
||||||
|
from_pixels: bool = False,
|
||||||
|
pixels_only: bool = False,
|
||||||
|
image_size=None,
|
||||||
|
seed=1337,
|
||||||
|
device="cpu",
|
||||||
|
num_prev_obs=1,
|
||||||
|
num_prev_action=0,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
task=task,
|
||||||
|
frame_skip=frame_skip,
|
||||||
|
from_pixels=from_pixels,
|
||||||
|
pixels_only=pixels_only,
|
||||||
|
image_size=image_size,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
num_prev_obs=num_prev_obs,
|
||||||
|
num_prev_action=num_prev_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_env(self):
|
||||||
|
if not _has_gym:
|
||||||
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
|
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||||
|
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||||
|
|
||||||
|
if not self.from_pixels:
|
||||||
|
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||||
|
from lerobot.common.envs.pusht.pusht_image_env import PushTImageEnv
|
||||||
|
|
||||||
|
self._env = PushTImageEnv(render_size=self.image_size)
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
|
||||||
|
"""
|
||||||
|
with_marker adds a cursor showing the targeted action for the controller.
|
||||||
|
"""
|
||||||
|
if width != height:
|
||||||
|
raise NotImplementedError()
|
||||||
|
tmp = self._env.render_size
|
||||||
|
if width != self._env.render_size:
|
||||||
|
self._env.render_cache = None
|
||||||
|
self._env.render_size = width
|
||||||
|
out = self._env.render(mode).copy()
|
||||||
|
if with_marker and self._env.latest_action is not None:
|
||||||
|
action = np.array(self._env.latest_action)
|
||||||
|
coord = (action / 512 * self._env.render_size).astype(np.int32)
|
||||||
|
marker_size = int(8 / 96 * self._env.render_size)
|
||||||
|
thickness = int(1 / 96 * self._env.render_size)
|
||||||
|
cv2.drawMarker(
|
||||||
|
out,
|
||||||
|
coord,
|
||||||
|
color=(255, 0, 0),
|
||||||
|
markerType=cv2.MARKER_CROSS,
|
||||||
|
markerSize=marker_size,
|
||||||
|
thickness=thickness,
|
||||||
|
)
|
||||||
|
self._env.render_size = tmp
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _format_raw_obs(self, raw_obs):
|
||||||
|
if self.from_pixels:
|
||||||
|
image = torch.from_numpy(raw_obs["image"])
|
||||||
|
obs = {"image": image}
|
||||||
|
|
||||||
|
if not self.pixels_only:
|
||||||
|
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||||
|
else:
|
||||||
|
# TODO:
|
||||||
|
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
if tensordict is not None and not PushtEnv._reset_warning_issued:
|
||||||
|
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||||
|
PushtEnv._reset_warning_issued = True
|
||||||
|
|
||||||
|
# Seed the environment and update the seed to be used for the next reset.
|
||||||
|
self._next_seed = self.set_seed(self._next_seed)
|
||||||
|
raw_obs = self._env.reset()
|
||||||
|
|
||||||
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue = deque(
|
||||||
|
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue = deque(
|
||||||
|
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
"done": torch.tensor([False], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _step(self, tensordict: TensorDict):
|
||||||
|
td = tensordict
|
||||||
|
action = td["action"].numpy()
|
||||||
|
assert action.ndim == 1
|
||||||
|
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||||
|
|
||||||
|
raw_obs, reward, done, info = self._env.step(action)
|
||||||
|
|
||||||
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue.append(obs["image"])
|
||||||
|
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue.append(obs["state"])
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||||
|
# success and done are true when coverage > self.success_threshold in env
|
||||||
|
"done": torch.tensor([done], dtype=torch.bool),
|
||||||
|
"success": torch.tensor([done], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _make_spec(self):
|
||||||
|
obs = {}
|
||||||
|
if self.from_pixels:
|
||||||
|
image_shape = (3, self.image_size, self.image_size)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||||
|
|
||||||
|
obs["image"] = BoundedTensorSpec(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=image_shape,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
if not self.pixels_only:
|
||||||
|
state_shape = self._env.observation_space["agent_pos"].shape
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = BoundedTensorSpec(
|
||||||
|
low=0,
|
||||||
|
high=512,
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||||
|
state_shape = self._env.observation_space["observation"].shape
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
|
# TODO:
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.observation_spec = CompositeSpec({"observation": obs})
|
||||||
|
|
||||||
|
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||||
|
self._env.action_space,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.done_spec = CompositeSpec(
|
||||||
|
{
|
||||||
|
"done": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"success": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_seed(self, seed: Optional[int]):
|
||||||
|
# Set global seed.
|
||||||
|
set_global_seed(seed)
|
||||||
|
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
|
||||||
|
self._env.seed(seed)
|
||||||
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
378
lerobot/common/envs/pusht/pusht_env.py
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
import collections
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
import pymunk
|
||||||
|
import pymunk.pygame_util
|
||||||
|
import shapely.geometry as sg
|
||||||
|
import skimage.transform as st
|
||||||
|
from gymnasium import spaces
|
||||||
|
from pymunk.vec2d import Vec2d
|
||||||
|
|
||||||
|
from lerobot.common.envs.pusht.pymunk_override import DrawOptions
|
||||||
|
|
||||||
|
|
||||||
|
def pymunk_to_shapely(body, shapes):
|
||||||
|
geoms = []
|
||||||
|
for shape in shapes:
|
||||||
|
if isinstance(shape, pymunk.shapes.Poly):
|
||||||
|
verts = [body.local_to_world(v) for v in shape.get_vertices()]
|
||||||
|
verts += [verts[0]]
|
||||||
|
geoms.append(sg.Polygon(verts))
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported shape type {type(shape)}")
|
||||||
|
geom = sg.MultiPolygon(geoms)
|
||||||
|
return geom
|
||||||
|
|
||||||
|
|
||||||
|
class PushTEnv(gym.Env):
|
||||||
|
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10}
|
||||||
|
reward_range = (0.0, 1.0)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
legacy=True, # compatibility with original
|
||||||
|
block_cog=None,
|
||||||
|
damping=None,
|
||||||
|
render_action=True,
|
||||||
|
render_size=96,
|
||||||
|
reset_to_state=None,
|
||||||
|
):
|
||||||
|
self._seed = None
|
||||||
|
self.seed()
|
||||||
|
self.window_size = ws = 512 # The size of the PyGame window
|
||||||
|
self.render_size = render_size
|
||||||
|
self.sim_hz = 100
|
||||||
|
# Local controller params.
|
||||||
|
self.k_p, self.k_v = 100, 20 # PD control.z
|
||||||
|
self.control_hz = self.metadata["video.frames_per_second"]
|
||||||
|
# legcay set_state for data compatibility
|
||||||
|
self.legacy = legacy
|
||||||
|
|
||||||
|
# agent_pos, block_pos, block_angle
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=np.array([0, 0, 0, 0, 0], dtype=np.float64),
|
||||||
|
high=np.array([ws, ws, ws, ws, np.pi * 2], dtype=np.float64),
|
||||||
|
shape=(5,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
# positional goal for agent
|
||||||
|
self.action_space = spaces.Box(
|
||||||
|
low=np.array([0, 0], dtype=np.float64),
|
||||||
|
high=np.array([ws, ws], dtype=np.float64),
|
||||||
|
shape=(2,),
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.block_cog = block_cog
|
||||||
|
self.damping = damping
|
||||||
|
self.render_action = render_action
|
||||||
|
|
||||||
|
"""
|
||||||
|
If human-rendering is used, `self.window` will be a reference
|
||||||
|
to the window that we draw to. `self.clock` will be a clock that is used
|
||||||
|
to ensure that the environment is rendered at the correct framerate in
|
||||||
|
human-mode. They will remain `None` until human-mode is used for the
|
||||||
|
first time.
|
||||||
|
"""
|
||||||
|
self.window = None
|
||||||
|
self.clock = None
|
||||||
|
self.screen = None
|
||||||
|
|
||||||
|
self.space = None
|
||||||
|
self.teleop = None
|
||||||
|
self.render_buffer = None
|
||||||
|
self.latest_action = None
|
||||||
|
self.reset_to_state = reset_to_state
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
seed = self._seed
|
||||||
|
self._setup()
|
||||||
|
if self.block_cog is not None:
|
||||||
|
self.block.center_of_gravity = self.block_cog
|
||||||
|
if self.damping is not None:
|
||||||
|
self.space.damping = self.damping
|
||||||
|
|
||||||
|
# use legacy RandomState for compatibility
|
||||||
|
state = self.reset_to_state
|
||||||
|
if state is None:
|
||||||
|
rs = np.random.RandomState(seed=seed)
|
||||||
|
state = np.array(
|
||||||
|
[
|
||||||
|
rs.randint(50, 450),
|
||||||
|
rs.randint(50, 450),
|
||||||
|
rs.randint(100, 400),
|
||||||
|
rs.randint(100, 400),
|
||||||
|
rs.randn() * 2 * np.pi - np.pi,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._set_state(state)
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
dt = 1.0 / self.sim_hz
|
||||||
|
self.n_contact_points = 0
|
||||||
|
n_steps = self.sim_hz // self.control_hz
|
||||||
|
if action is not None:
|
||||||
|
self.latest_action = action
|
||||||
|
for _ in range(n_steps):
|
||||||
|
# Step PD control.
|
||||||
|
# self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too.
|
||||||
|
acceleration = self.k_p * (action - self.agent.position) + self.k_v * (
|
||||||
|
Vec2d(0, 0) - self.agent.velocity
|
||||||
|
)
|
||||||
|
self.agent.velocity += acceleration * dt
|
||||||
|
|
||||||
|
# Step physics.
|
||||||
|
self.space.step(dt)
|
||||||
|
|
||||||
|
# compute reward
|
||||||
|
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||||
|
goal_geom = pymunk_to_shapely(goal_body, self.block.shapes)
|
||||||
|
block_geom = pymunk_to_shapely(self.block, self.block.shapes)
|
||||||
|
|
||||||
|
intersection_area = goal_geom.intersection(block_geom).area
|
||||||
|
goal_area = goal_geom.area
|
||||||
|
coverage = intersection_area / goal_area
|
||||||
|
reward = np.clip(coverage / self.success_threshold, 0, 1)
|
||||||
|
done = coverage > self.success_threshold
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
|
info = self._get_info()
|
||||||
|
|
||||||
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
def render(self, mode):
|
||||||
|
return self._render_frame(mode)
|
||||||
|
|
||||||
|
def teleop_agent(self):
|
||||||
|
TeleopAgent = collections.namedtuple("TeleopAgent", ["act"])
|
||||||
|
|
||||||
|
def act(obs):
|
||||||
|
act = None
|
||||||
|
mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen)
|
||||||
|
if self.teleop or (mouse_position - self.agent.position).length < 30:
|
||||||
|
self.teleop = True
|
||||||
|
act = mouse_position
|
||||||
|
return act
|
||||||
|
|
||||||
|
return TeleopAgent(act)
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
obs = np.array(
|
||||||
|
tuple(self.agent.position) + tuple(self.block.position) + (self.block.angle % (2 * np.pi),)
|
||||||
|
)
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _get_goal_pose_body(self, pose):
|
||||||
|
mass = 1
|
||||||
|
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||||
|
body = pymunk.Body(mass, inertia)
|
||||||
|
# preserving the legacy assignment order for compatibility
|
||||||
|
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||||
|
body.position = pose[:2].tolist()
|
||||||
|
body.angle = pose[2]
|
||||||
|
return body
|
||||||
|
|
||||||
|
def _get_info(self):
|
||||||
|
n_steps = self.sim_hz // self.control_hz
|
||||||
|
n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps))
|
||||||
|
info = {
|
||||||
|
"pos_agent": np.array(self.agent.position),
|
||||||
|
"vel_agent": np.array(self.agent.velocity),
|
||||||
|
"block_pose": np.array(list(self.block.position) + [self.block.angle]),
|
||||||
|
"goal_pose": self.goal_pose,
|
||||||
|
"n_contacts": n_contact_points_per_step,
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
|
||||||
|
def _render_frame(self, mode):
|
||||||
|
if self.window is None and mode == "human":
|
||||||
|
pygame.init()
|
||||||
|
pygame.display.init()
|
||||||
|
self.window = pygame.display.set_mode((self.window_size, self.window_size))
|
||||||
|
if self.clock is None and mode == "human":
|
||||||
|
self.clock = pygame.time.Clock()
|
||||||
|
|
||||||
|
canvas = pygame.Surface((self.window_size, self.window_size))
|
||||||
|
canvas.fill((255, 255, 255))
|
||||||
|
self.screen = canvas
|
||||||
|
|
||||||
|
draw_options = DrawOptions(canvas)
|
||||||
|
|
||||||
|
# Draw goal pose.
|
||||||
|
goal_body = self._get_goal_pose_body(self.goal_pose)
|
||||||
|
for shape in self.block.shapes:
|
||||||
|
goal_points = [
|
||||||
|
pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface)
|
||||||
|
for v in shape.get_vertices()
|
||||||
|
]
|
||||||
|
goal_points += [goal_points[0]]
|
||||||
|
pygame.draw.polygon(canvas, self.goal_color, goal_points)
|
||||||
|
|
||||||
|
# Draw agent and block.
|
||||||
|
self.space.debug_draw(draw_options)
|
||||||
|
|
||||||
|
if mode == "human":
|
||||||
|
# The following line copies our drawings from `canvas` to the visible window
|
||||||
|
self.window.blit(canvas, canvas.get_rect())
|
||||||
|
pygame.event.pump()
|
||||||
|
pygame.display.update()
|
||||||
|
|
||||||
|
# the clock is already ticked during in step for "human"
|
||||||
|
|
||||||
|
img = np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
|
||||||
|
img = cv2.resize(img, (self.render_size, self.render_size))
|
||||||
|
if self.render_action and self.latest_action is not None:
|
||||||
|
action = np.array(self.latest_action)
|
||||||
|
coord = (action / 512 * 96).astype(np.int32)
|
||||||
|
marker_size = int(8 / 96 * self.render_size)
|
||||||
|
thickness = int(1 / 96 * self.render_size)
|
||||||
|
cv2.drawMarker(
|
||||||
|
img,
|
||||||
|
coord,
|
||||||
|
color=(255, 0, 0),
|
||||||
|
markerType=cv2.MARKER_CROSS,
|
||||||
|
markerSize=marker_size,
|
||||||
|
thickness=thickness,
|
||||||
|
)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.window is not None:
|
||||||
|
pygame.display.quit()
|
||||||
|
pygame.quit()
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
if seed is None:
|
||||||
|
seed = np.random.randint(0, 25536)
|
||||||
|
self._seed = seed
|
||||||
|
self.np_random = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
def _handle_collision(self, arbiter, space, data):
|
||||||
|
self.n_contact_points += len(arbiter.contact_point_set.points)
|
||||||
|
|
||||||
|
def _set_state(self, state):
|
||||||
|
if isinstance(state, np.ndarray):
|
||||||
|
state = state.tolist()
|
||||||
|
pos_agent = state[:2]
|
||||||
|
pos_block = state[2:4]
|
||||||
|
rot_block = state[4]
|
||||||
|
self.agent.position = pos_agent
|
||||||
|
# setting angle rotates with respect to center of mass
|
||||||
|
# therefore will modify the geometric position
|
||||||
|
# if not the same as CoM
|
||||||
|
# therefore should be modified first.
|
||||||
|
if self.legacy:
|
||||||
|
# for compatibility with legacy data
|
||||||
|
self.block.position = pos_block
|
||||||
|
self.block.angle = rot_block
|
||||||
|
else:
|
||||||
|
self.block.angle = rot_block
|
||||||
|
self.block.position = pos_block
|
||||||
|
|
||||||
|
# Run physics to take effect
|
||||||
|
self.space.step(1.0 / self.sim_hz)
|
||||||
|
|
||||||
|
def _set_state_local(self, state_local):
|
||||||
|
agent_pos_local = state_local[:2]
|
||||||
|
block_pose_local = state_local[2:]
|
||||||
|
tf_img_obj = st.AffineTransform(translation=self.goal_pose[:2], rotation=self.goal_pose[2])
|
||||||
|
tf_obj_new = st.AffineTransform(translation=block_pose_local[:2], rotation=block_pose_local[2])
|
||||||
|
tf_img_new = st.AffineTransform(matrix=tf_img_obj.params @ tf_obj_new.params)
|
||||||
|
agent_pos_new = tf_img_new(agent_pos_local)
|
||||||
|
new_state = np.array(list(agent_pos_new[0]) + list(tf_img_new.translation) + [tf_img_new.rotation])
|
||||||
|
self._set_state(new_state)
|
||||||
|
return new_state
|
||||||
|
|
||||||
|
def _setup(self):
|
||||||
|
self.space = pymunk.Space()
|
||||||
|
self.space.gravity = 0, 0
|
||||||
|
self.space.damping = 0
|
||||||
|
self.teleop = False
|
||||||
|
self.render_buffer = []
|
||||||
|
|
||||||
|
# Add walls.
|
||||||
|
walls = [
|
||||||
|
self._add_segment((5, 506), (5, 5), 2),
|
||||||
|
self._add_segment((5, 5), (506, 5), 2),
|
||||||
|
self._add_segment((506, 5), (506, 506), 2),
|
||||||
|
self._add_segment((5, 506), (506, 506), 2),
|
||||||
|
]
|
||||||
|
self.space.add(*walls)
|
||||||
|
|
||||||
|
# Add agent, block, and goal zone.
|
||||||
|
self.agent = self.add_circle((256, 400), 15)
|
||||||
|
self.block = self.add_tee((256, 300), 0)
|
||||||
|
self.goal_color = pygame.Color("LightGreen")
|
||||||
|
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||||
|
|
||||||
|
# Add collision handling
|
||||||
|
self.collision_handeler = self.space.add_collision_handler(0, 0)
|
||||||
|
self.collision_handeler.post_solve = self._handle_collision
|
||||||
|
self.n_contact_points = 0
|
||||||
|
|
||||||
|
self.max_score = 50 * 100
|
||||||
|
self.success_threshold = 0.95 # 95% coverage.
|
||||||
|
|
||||||
|
def _add_segment(self, a, b, radius):
|
||||||
|
shape = pymunk.Segment(self.space.static_body, a, b, radius)
|
||||||
|
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||||
|
return shape
|
||||||
|
|
||||||
|
def add_circle(self, position, radius):
|
||||||
|
body = pymunk.Body(body_type=pymunk.Body.KINEMATIC)
|
||||||
|
body.position = position
|
||||||
|
body.friction = 1
|
||||||
|
shape = pymunk.Circle(body, radius)
|
||||||
|
shape.color = pygame.Color("RoyalBlue")
|
||||||
|
self.space.add(body, shape)
|
||||||
|
return body
|
||||||
|
|
||||||
|
def add_box(self, position, height, width):
|
||||||
|
mass = 1
|
||||||
|
inertia = pymunk.moment_for_box(mass, (height, width))
|
||||||
|
body = pymunk.Body(mass, inertia)
|
||||||
|
body.position = position
|
||||||
|
shape = pymunk.Poly.create_box(body, (height, width))
|
||||||
|
shape.color = pygame.Color("LightSlateGray")
|
||||||
|
self.space.add(body, shape)
|
||||||
|
return body
|
||||||
|
|
||||||
|
def add_tee(self, position, angle, scale=30, color="LightSlateGray", mask=None):
|
||||||
|
if mask is None:
|
||||||
|
mask = pymunk.ShapeFilter.ALL_MASKS()
|
||||||
|
mass = 1
|
||||||
|
length = 4
|
||||||
|
vertices1 = [
|
||||||
|
(-length * scale / 2, scale),
|
||||||
|
(length * scale / 2, scale),
|
||||||
|
(length * scale / 2, 0),
|
||||||
|
(-length * scale / 2, 0),
|
||||||
|
]
|
||||||
|
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
vertices2 = [
|
||||||
|
(-scale / 2, scale),
|
||||||
|
(-scale / 2, length * scale),
|
||||||
|
(scale / 2, length * scale),
|
||||||
|
(scale / 2, scale),
|
||||||
|
]
|
||||||
|
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||||
|
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||||
|
shape1 = pymunk.Poly(body, vertices1)
|
||||||
|
shape2 = pymunk.Poly(body, vertices2)
|
||||||
|
shape1.color = pygame.Color(color)
|
||||||
|
shape2.color = pygame.Color(color)
|
||||||
|
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||||
|
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||||
|
body.position = position
|
||||||
|
body.angle = angle
|
||||||
|
body.friction = 1
|
||||||
|
self.space.add(body, shape1, shape2)
|
||||||
|
return body
|
||||||
41
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
41
lerobot/common/envs/pusht/pusht_image_env.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||||
|
|
||||||
|
|
||||||
|
class PushTImageEnv(PushTEnv):
|
||||||
|
metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10}
|
||||||
|
|
||||||
|
# Note: legacy defaults to True for compatibility with original
|
||||||
|
def __init__(self, legacy=True, block_cog=None, damping=None, render_size=96):
|
||||||
|
super().__init__(
|
||||||
|
legacy=legacy, block_cog=block_cog, damping=damping, render_size=render_size, render_action=False
|
||||||
|
)
|
||||||
|
ws = self.window_size
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"image": spaces.Box(low=0, high=1, shape=(3, render_size, render_size), dtype=np.float32),
|
||||||
|
"agent_pos": spaces.Box(low=0, high=ws, shape=(2,), dtype=np.float32),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.render_cache = None
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
img = super()._render_frame(mode="rgb_array")
|
||||||
|
|
||||||
|
agent_pos = np.array(self.agent.position)
|
||||||
|
img_obs = np.moveaxis(img, -1, 0)
|
||||||
|
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||||
|
|
||||||
|
self.render_cache = img
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def render(self, mode):
|
||||||
|
assert mode == "rgb_array"
|
||||||
|
|
||||||
|
if self.render_cache is None:
|
||||||
|
self._get_obs()
|
||||||
|
|
||||||
|
return self.render_cache
|
||||||
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
244
lerobot/common/envs/pusht/pymunk_override.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# pymunk
|
||||||
|
# Copyright (c) 2007-2016 Victor Blomqvist
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
"""This submodule contains helper functions to help with quick prototyping
|
||||||
|
using pymunk together with pygame.
|
||||||
|
|
||||||
|
Intended to help with debugging and prototyping, not for actual production use
|
||||||
|
in a full application. The methods contained in this module is opinionated
|
||||||
|
about your coordinate system and not in any way optimized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__docformat__ = "reStructuredText"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DrawOptions",
|
||||||
|
"get_mouse_pos",
|
||||||
|
"to_pygame",
|
||||||
|
"from_pygame",
|
||||||
|
# "lighten",
|
||||||
|
"positive_y_is_up",
|
||||||
|
]
|
||||||
|
|
||||||
|
from typing import Sequence, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
import pymunk
|
||||||
|
from pymunk.space_debug_draw_options import SpaceDebugColor
|
||||||
|
from pymunk.vec2d import Vec2d
|
||||||
|
|
||||||
|
positive_y_is_up: bool = False
|
||||||
|
"""Make increasing values of y point upwards.
|
||||||
|
|
||||||
|
When True::
|
||||||
|
|
||||||
|
y
|
||||||
|
^
|
||||||
|
| . (3, 3)
|
||||||
|
|
|
||||||
|
| . (2, 2)
|
||||||
|
|
|
||||||
|
+------ > x
|
||||||
|
|
||||||
|
When False::
|
||||||
|
|
||||||
|
+------ > x
|
||||||
|
|
|
||||||
|
| . (2, 2)
|
||||||
|
|
|
||||||
|
| . (3, 3)
|
||||||
|
v
|
||||||
|
y
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class DrawOptions(pymunk.SpaceDebugDrawOptions):
|
||||||
|
def __init__(self, surface: pygame.Surface) -> None:
|
||||||
|
"""Draw a pymunk.Space on a pygame.Surface object.
|
||||||
|
|
||||||
|
Typical usage::
|
||||||
|
|
||||||
|
>>> import pymunk
|
||||||
|
>>> surface = pygame.Surface((10,10))
|
||||||
|
>>> space = pymunk.Space()
|
||||||
|
>>> options = pymunk.pygame_util.DrawOptions(surface)
|
||||||
|
>>> space.debug_draw(options)
|
||||||
|
|
||||||
|
You can control the color of a shape by setting shape.color to the color
|
||||||
|
you want it drawn in::
|
||||||
|
|
||||||
|
>>> c = pymunk.Circle(None, 10)
|
||||||
|
>>> c.color = pygame.Color("pink")
|
||||||
|
|
||||||
|
See pygame_util.demo.py for a full example
|
||||||
|
|
||||||
|
Since pygame uses a coordinate system where y points down (in contrast
|
||||||
|
to many other cases), you either have to make the physics simulation
|
||||||
|
with Pymunk also behave in that way, or flip everything when you draw.
|
||||||
|
|
||||||
|
The easiest is probably to just make the simulation behave the same
|
||||||
|
way as Pygame does. In that way all coordinates used are in the same
|
||||||
|
orientation and easy to reason about::
|
||||||
|
|
||||||
|
>>> space = pymunk.Space()
|
||||||
|
>>> space.gravity = (0, -1000)
|
||||||
|
>>> body = pymunk.Body()
|
||||||
|
>>> body.position = (0, 0) # will be positioned in the top left corner
|
||||||
|
>>> space.debug_draw(options)
|
||||||
|
|
||||||
|
To flip the drawing its possible to set the module property
|
||||||
|
:py:data:`positive_y_is_up` to True. Then the pygame drawing will flip
|
||||||
|
the simulation upside down before drawing::
|
||||||
|
|
||||||
|
>>> positive_y_is_up = True
|
||||||
|
>>> body = pymunk.Body()
|
||||||
|
>>> body.position = (0, 0)
|
||||||
|
>>> # Body will be position in bottom left corner
|
||||||
|
|
||||||
|
:Parameters:
|
||||||
|
surface : pygame.Surface
|
||||||
|
Surface that the objects will be drawn on
|
||||||
|
"""
|
||||||
|
self.surface = surface
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def draw_circle(
|
||||||
|
self,
|
||||||
|
pos: Vec2d,
|
||||||
|
angle: float,
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
p = to_pygame(pos, self.surface)
|
||||||
|
|
||||||
|
pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0)
|
||||||
|
pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius - 4), 0)
|
||||||
|
|
||||||
|
# circle_edge = pos + Vec2d(radius, 0).rotated(angle)
|
||||||
|
# p2 = to_pygame(circle_edge, self.surface)
|
||||||
|
# line_r = 2 if radius > 20 else 1
|
||||||
|
# pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r)
|
||||||
|
|
||||||
|
def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None:
|
||||||
|
p1 = to_pygame(a, self.surface)
|
||||||
|
p2 = to_pygame(b, self.surface)
|
||||||
|
|
||||||
|
pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2])
|
||||||
|
|
||||||
|
def draw_fat_segment(
|
||||||
|
self,
|
||||||
|
a: Tuple[float, float],
|
||||||
|
b: Tuple[float, float],
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
p1 = to_pygame(a, self.surface)
|
||||||
|
p2 = to_pygame(b, self.surface)
|
||||||
|
|
||||||
|
r = round(max(1, radius * 2))
|
||||||
|
pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r)
|
||||||
|
if r > 2:
|
||||||
|
orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])]
|
||||||
|
if orthog[0] == 0 and orthog[1] == 0:
|
||||||
|
return
|
||||||
|
scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5
|
||||||
|
orthog[0] = round(orthog[0] * scale)
|
||||||
|
orthog[1] = round(orthog[1] * scale)
|
||||||
|
points = [
|
||||||
|
(p1[0] - orthog[0], p1[1] - orthog[1]),
|
||||||
|
(p1[0] + orthog[0], p1[1] + orthog[1]),
|
||||||
|
(p2[0] + orthog[0], p2[1] + orthog[1]),
|
||||||
|
(p2[0] - orthog[0], p2[1] - orthog[1]),
|
||||||
|
]
|
||||||
|
pygame.draw.polygon(self.surface, fill_color.as_int(), points)
|
||||||
|
pygame.draw.circle(
|
||||||
|
self.surface,
|
||||||
|
fill_color.as_int(),
|
||||||
|
(round(p1[0]), round(p1[1])),
|
||||||
|
round(radius),
|
||||||
|
)
|
||||||
|
pygame.draw.circle(
|
||||||
|
self.surface,
|
||||||
|
fill_color.as_int(),
|
||||||
|
(round(p2[0]), round(p2[1])),
|
||||||
|
round(radius),
|
||||||
|
)
|
||||||
|
|
||||||
|
def draw_polygon(
|
||||||
|
self,
|
||||||
|
verts: Sequence[Tuple[float, float]],
|
||||||
|
radius: float,
|
||||||
|
outline_color: SpaceDebugColor,
|
||||||
|
fill_color: SpaceDebugColor,
|
||||||
|
) -> None:
|
||||||
|
ps = [to_pygame(v, self.surface) for v in verts]
|
||||||
|
ps += [ps[0]]
|
||||||
|
|
||||||
|
radius = 2
|
||||||
|
pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps)
|
||||||
|
|
||||||
|
if radius > 0:
|
||||||
|
for i in range(len(verts)):
|
||||||
|
a = verts[i]
|
||||||
|
b = verts[(i + 1) % len(verts)]
|
||||||
|
self.draw_fat_segment(a, b, radius, fill_color, fill_color)
|
||||||
|
|
||||||
|
def draw_dot(self, size: float, pos: Tuple[float, float], color: SpaceDebugColor) -> None:
|
||||||
|
p = to_pygame(pos, self.surface)
|
||||||
|
pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mouse_pos(surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Get position of the mouse pointer in pymunk coordinates."""
|
||||||
|
p = pygame.mouse.get_pos()
|
||||||
|
return from_pygame(p, surface)
|
||||||
|
|
||||||
|
|
||||||
|
def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Convenience method to convert pymunk coordinates to pygame surface
|
||||||
|
local coordinates.
|
||||||
|
|
||||||
|
Note that in case positive_y_is_up is False, this function won't actually do
|
||||||
|
anything except converting the point to integers.
|
||||||
|
"""
|
||||||
|
if positive_y_is_up:
|
||||||
|
return round(p[0]), surface.get_height() - round(p[1])
|
||||||
|
else:
|
||||||
|
return round(p[0]), round(p[1])
|
||||||
|
|
||||||
|
|
||||||
|
def from_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]:
|
||||||
|
"""Convenience method to convert pygame surface local coordinates to
|
||||||
|
pymunk coordinates
|
||||||
|
"""
|
||||||
|
return to_pygame(p, surface)
|
||||||
|
|
||||||
|
|
||||||
|
def light_color(color: SpaceDebugColor):
|
||||||
|
color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255]))
|
||||||
|
color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3])
|
||||||
|
return color
|
||||||
237
lerobot/common/envs/simxarm/env.py
Normal file
237
lerobot/common/envs/simxarm/env.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from torchrl.data.tensor_specs import (
|
||||||
|
BoundedTensorSpec,
|
||||||
|
CompositeSpec,
|
||||||
|
DiscreteTensorSpec,
|
||||||
|
UnboundedContinuousTensorSpec,
|
||||||
|
)
|
||||||
|
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||||
|
|
||||||
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
|
MAX_NUM_ACTIONS = 4
|
||||||
|
|
||||||
|
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||||
|
|
||||||
|
|
||||||
|
class SimxarmEnv(AbstractEnv):
|
||||||
|
name = "simxarm"
|
||||||
|
available_tasks = ["lift"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
frame_skip: int = 1,
|
||||||
|
from_pixels: bool = False,
|
||||||
|
pixels_only: bool = False,
|
||||||
|
image_size=None,
|
||||||
|
seed=1337,
|
||||||
|
device="cpu",
|
||||||
|
num_prev_obs=0,
|
||||||
|
num_prev_action=0,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
task=task,
|
||||||
|
frame_skip=frame_skip,
|
||||||
|
from_pixels=from_pixels,
|
||||||
|
pixels_only=pixels_only,
|
||||||
|
image_size=image_size,
|
||||||
|
seed=seed,
|
||||||
|
device=device,
|
||||||
|
num_prev_obs=num_prev_obs,
|
||||||
|
num_prev_action=num_prev_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_env(self):
|
||||||
|
if not _has_gym:
|
||||||
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm import TASKS
|
||||||
|
|
||||||
|
if self.task not in TASKS:
|
||||||
|
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||||
|
|
||||||
|
self._env = TASKS[self.task]["env"]()
|
||||||
|
|
||||||
|
num_actions = len(TASKS[self.task]["action_space"])
|
||||||
|
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||||
|
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||||
|
if "w" not in TASKS[self.task]["action_space"]:
|
||||||
|
self._action_padding[-1] = 1.0
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=384, height=384):
|
||||||
|
return self._env.render(mode, width=width, height=height)
|
||||||
|
|
||||||
|
def _format_raw_obs(self, raw_obs):
|
||||||
|
if self.from_pixels:
|
||||||
|
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||||
|
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||||
|
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||||
|
|
||||||
|
obs = {"image": image}
|
||||||
|
|
||||||
|
if not self.pixels_only:
|
||||||
|
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||||
|
|
||||||
|
# obs = TensorDict(obs, batch_size=[])
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
|
td = tensordict
|
||||||
|
if td is None or td.is_empty():
|
||||||
|
raw_obs = self._env.reset()
|
||||||
|
|
||||||
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue = deque(
|
||||||
|
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue = deque(
|
||||||
|
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||||
|
)
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": TensorDict(obs, batch_size=[]),
|
||||||
|
"done": torch.tensor([False], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _step(self, tensordict: TensorDict):
|
||||||
|
td = tensordict
|
||||||
|
action = td["action"].numpy()
|
||||||
|
# step expects shape=(4,) so we pad if necessary
|
||||||
|
action = np.concatenate([action, self._action_padding])
|
||||||
|
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||||
|
sum_reward = 0
|
||||||
|
|
||||||
|
if action.ndim == 1:
|
||||||
|
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||||
|
else:
|
||||||
|
if self.frame_skip > 1:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
num_action_steps = action.shape[0]
|
||||||
|
for i in range(num_action_steps):
|
||||||
|
raw_obs, reward, done, info = self._env.step(action[i])
|
||||||
|
sum_reward += reward
|
||||||
|
|
||||||
|
obs = self._format_raw_obs(raw_obs)
|
||||||
|
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
stacked_obs = {}
|
||||||
|
if "image" in obs:
|
||||||
|
self._prev_obs_image_queue.append(obs["image"])
|
||||||
|
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||||
|
if "state" in obs:
|
||||||
|
self._prev_obs_state_queue.append(obs["state"])
|
||||||
|
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||||
|
obs = stacked_obs
|
||||||
|
|
||||||
|
td = TensorDict(
|
||||||
|
{
|
||||||
|
"observation": self._format_raw_obs(raw_obs),
|
||||||
|
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||||
|
"done": torch.tensor([done], dtype=torch.bool),
|
||||||
|
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||||
|
},
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
return td
|
||||||
|
|
||||||
|
def _make_spec(self):
|
||||||
|
obs = {}
|
||||||
|
if self.from_pixels:
|
||||||
|
image_shape = (3, self.image_size, self.image_size)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||||
|
|
||||||
|
obs["image"] = BoundedTensorSpec(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=image_shape,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
if not self.pixels_only:
|
||||||
|
state_shape = (len(self._env.robot_state),)
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||||
|
state_shape = self._env.observation_space["observation"].shape
|
||||||
|
if self.num_prev_obs > 0:
|
||||||
|
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||||
|
|
||||||
|
obs["state"] = UnboundedContinuousTensorSpec(
|
||||||
|
# TODO:
|
||||||
|
shape=state_shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.observation_spec = CompositeSpec({"observation": obs})
|
||||||
|
|
||||||
|
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||||
|
self._action_space,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.done_spec = CompositeSpec(
|
||||||
|
{
|
||||||
|
"done": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"success": DiscreteTensorSpec(
|
||||||
|
2,
|
||||||
|
shape=(1,),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_seed(self, seed: Optional[int]):
|
||||||
|
set_global_seed(seed)
|
||||||
|
self._seed = seed
|
||||||
|
# TODO(aliberts): change self._reset so that it takes in a seed value
|
||||||
|
logging.warning("simxarm env is not properly seeded")
|
||||||
166
lerobot/common/envs/simxarm/simxarm/__init__.py
Normal file
166
lerobot/common/envs/simxarm/simxarm/__init__.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from collections import OrderedDict, deque
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium.wrappers import TimeLimit
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
|
||||||
|
|
||||||
|
TASKS = OrderedDict(
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"reach",
|
||||||
|
{
|
||||||
|
"env": Reach,
|
||||||
|
"action_space": "xyz",
|
||||||
|
"episode_length": 50,
|
||||||
|
"description": "Reach a target location with the end effector",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"push",
|
||||||
|
{
|
||||||
|
"env": Push,
|
||||||
|
"action_space": "xyz",
|
||||||
|
"episode_length": 50,
|
||||||
|
"description": "Push a cube to a target location",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"peg_in_box",
|
||||||
|
{
|
||||||
|
"env": PegInBox,
|
||||||
|
"action_space": "xyz",
|
||||||
|
"episode_length": 50,
|
||||||
|
"description": "Insert a peg into a box",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"lift",
|
||||||
|
{
|
||||||
|
"env": Lift,
|
||||||
|
"action_space": "xyzw",
|
||||||
|
"episode_length": 50,
|
||||||
|
"description": "Lift a cube above a height threshold",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SimXarmWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
A wrapper for the SimXarm environments. This wrapper is used to
|
||||||
|
convert the action and observation spaces to the correct format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False):
|
||||||
|
super().__init__(env)
|
||||||
|
self._env = env
|
||||||
|
self.obs_mode = obs_mode
|
||||||
|
self.image_size = image_size
|
||||||
|
self.action_repeat = action_repeat
|
||||||
|
self.frame_stack = frame_stack
|
||||||
|
self._frames = deque([], maxlen=frame_stack)
|
||||||
|
self.channel_last = channel_last
|
||||||
|
self._max_episode_steps = task["episode_length"] // action_repeat
|
||||||
|
|
||||||
|
image_shape = (
|
||||||
|
(image_size, image_size, 3 * frame_stack)
|
||||||
|
if channel_last
|
||||||
|
else (3 * frame_stack, image_size, image_size)
|
||||||
|
)
|
||||||
|
if obs_mode == "state":
|
||||||
|
self.observation_space = env.observation_space["observation"]
|
||||||
|
elif obs_mode == "rgb":
|
||||||
|
self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
|
||||||
|
elif obs_mode == "all":
|
||||||
|
self.observation_space = gym.spaces.Dict(
|
||||||
|
state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32),
|
||||||
|
rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]")
|
||||||
|
self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),))
|
||||||
|
self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32)
|
||||||
|
if "w" not in task["action_space"]:
|
||||||
|
self.action_padding[-1] = 1.0
|
||||||
|
|
||||||
|
def _render_obs(self):
|
||||||
|
obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||||
|
if not self.channel_last:
|
||||||
|
obs = obs.transpose(2, 0, 1)
|
||||||
|
return obs.copy()
|
||||||
|
|
||||||
|
def _update_frames(self, reset=False):
|
||||||
|
pixels = self._render_obs()
|
||||||
|
self._frames.append(pixels)
|
||||||
|
if reset:
|
||||||
|
for _ in range(1, self.frame_stack):
|
||||||
|
self._frames.append(pixels)
|
||||||
|
assert len(self._frames) == self.frame_stack
|
||||||
|
|
||||||
|
def transform_obs(self, obs, reset=False):
|
||||||
|
if self.obs_mode == "state":
|
||||||
|
return obs["observation"]
|
||||||
|
elif self.obs_mode == "rgb":
|
||||||
|
self._update_frames(reset=reset)
|
||||||
|
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
|
||||||
|
return rgb_obs
|
||||||
|
elif self.obs_mode == "all":
|
||||||
|
self._update_frames(reset=reset)
|
||||||
|
rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0)
|
||||||
|
return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state)))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]")
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self.transform_obs(self._env.reset(), reset=True)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
action = np.concatenate([action, self.action_padding])
|
||||||
|
reward = 0.0
|
||||||
|
for _ in range(self.action_repeat):
|
||||||
|
obs, r, done, info = self._env.step(action)
|
||||||
|
reward += r
|
||||||
|
return self.transform_obs(obs), reward, done, info
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=384, height=384, **kwargs):
|
||||||
|
return self._env.render(mode, width=width, height=height)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self._env.robot_state
|
||||||
|
|
||||||
|
|
||||||
|
def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0):
|
||||||
|
"""
|
||||||
|
Create a new environment.
|
||||||
|
Args:
|
||||||
|
task (str): The task to create an environment for. Must be one of:
|
||||||
|
- 'reach'
|
||||||
|
- 'push'
|
||||||
|
- 'peg-in-box'
|
||||||
|
- 'lift'
|
||||||
|
obs_mode (str): The observation mode to use. Must be one of:
|
||||||
|
- 'state': Only state observations
|
||||||
|
- 'rgb': RGB images
|
||||||
|
- 'all': RGB images and state observations
|
||||||
|
image_size (int): The size of the image observations
|
||||||
|
action_repeat (int): The number of times to repeat the action
|
||||||
|
seed (int): The random seed to use
|
||||||
|
Returns:
|
||||||
|
gym.Env: The environment
|
||||||
|
"""
|
||||||
|
if task not in TASKS:
|
||||||
|
raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}")
|
||||||
|
env = TASKS[task]["env"]()
|
||||||
|
env = TimeLimit(env, TASKS[task]["episode_length"])
|
||||||
|
env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last)
|
||||||
|
env.seed(seed)
|
||||||
|
|
||||||
|
return env
|
||||||
53
lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
Normal file
53
lerobot/common/envs/simxarm/simxarm/tasks/assets/lift.xml
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
|
||||||
|
<mujoco>
|
||||||
|
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||||
|
<size nconmax="2000" njmax="500"/>
|
||||||
|
|
||||||
|
<option timestep="0.002">
|
||||||
|
<flag warmstart="enable"></flag>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<include file="shared.xml"></include>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<body name="floor0" pos="0 0 0">
|
||||||
|
<geom name="floorgeom0" pos="1.2 -2.0 0" size="20.0 20.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<include file="xarm.xml"></include>
|
||||||
|
|
||||||
|
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||||
|
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||||
|
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body pos="1.5 0.075 0.3425" name="table0">
|
||||||
|
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 1 1"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="object" pos="1.405 0.3 0.58625">
|
||||||
|
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||||
|
<geom size="0.035 0.035 0.035" type="box" name="object0" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
|
||||||
|
<site name="object_site" pos="0 0 0" size="0.035 0.035 0.035" rgba="1 0 0 0" type="box"></site>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||||
|
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||||
|
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||||
|
|
||||||
|
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736
|
||||||
|
size 1211434
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66
|
||||||
|
size 3284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de
|
||||||
|
size 3284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86
|
||||||
|
size 6284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4
|
||||||
|
size 242684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac
|
||||||
|
size 366284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921
|
||||||
|
size 22284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636
|
||||||
|
size 184184
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455
|
||||||
|
size 225584
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42
|
||||||
|
size 237084
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089
|
||||||
|
size 243684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90
|
||||||
|
size 229084
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb
|
||||||
|
size 399384
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18
|
||||||
|
size 231684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0
|
||||||
|
size 2106184
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d
|
||||||
|
size 242684
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e
|
||||||
|
size 366284
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8
|
||||||
|
size 22284
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
|
||||||
|
<mujoco>
|
||||||
|
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||||
|
<size nconmax="2000" njmax="500"/>
|
||||||
|
|
||||||
|
<option timestep="0.001">
|
||||||
|
<flag warmstart="enable"></flag>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<include file="shared.xml"></include>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<body name="floor0" pos="0 0 0">
|
||||||
|
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<include file="xarm.xml"></include>
|
||||||
|
|
||||||
|
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||||
|
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||||
|
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body pos="1.5 0.075 0.3425" name="table0">
|
||||||
|
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="box0" pos="1.605 0.25 0.55">
|
||||||
|
<joint name="box_joint0" type="free" limited="false"></joint>
|
||||||
|
<site name="box_site" pos="0 0.075 -0.01" size="0.02" rgba="0 0 0 0" type="sphere"></site>
|
||||||
|
<geom name="box_side0" pos="0 0 0" size="0.065 0.002 0.04" type= "box" rgba="0.8 0.1 0.1 1" mass ="1" condim="4" />
|
||||||
|
<geom name="box_side1" pos="0 0.149 0" size="0.065 0.002 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
|
||||||
|
<geom name="box_side2" pos="0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.8 0.1 0.1 1" mass ="2" condim="4" />
|
||||||
|
<geom name="box_side3" pos="-0.064 0.074 0" size="0.002 0.075 0.04" type="box" rgba="0.9 0.2 0.2 1" mass ="2" condim="4" />
|
||||||
|
<geom name="box_side4" pos="-0 0.074 -0.038" size="0.065 0.075 0.002" type="box" rgba="0.5 0 0 1" mass ="2" condim="4"/>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="object0" pos="1.4 0.25 0.65">
|
||||||
|
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||||
|
<geom name="object_target0" type="cylinder" pos="0 0 -0.05" size="0.03 0.035" rgba="0.6 0.8 0.5 1" mass ="0.1" condim="3" />
|
||||||
|
<site name="object_site" pos="0 0 -0.05" size="0.0325 0.0375" rgba="0 0 0 0" type="cylinder"></site>
|
||||||
|
<body name="B0" pos="0 0 0" euler="0 0 0 ">
|
||||||
|
<joint name="B0:joint" type="slide" limited="true" axis="0 0 1" damping="0.05" range="0.0001 0.0001001" solimpfriction="0.98 0.98 0.95" frictionloss="1"></joint>
|
||||||
|
<geom type="capsule" size="0.002 0.03" rgba="0 0 0 1" mass="0.001" condim="4"/>
|
||||||
|
<body name="B1" pos="0 0 0.04" euler="0 3.14 0 ">
|
||||||
|
<joint name="B1:joint1" type="hinge" axis="1 0 0" range="-0.1 0.1" frictionloss="1"></joint>
|
||||||
|
<joint name="B1:joint2" type="hinge" axis="0 1 0" range="-0.1 0.1" frictionloss="1"></joint>
|
||||||
|
<joint name="B1:joint3" type="hinge" axis="0 0 1" range="-0.1 0.1" frictionloss="1"></joint>
|
||||||
|
<geom type="capsule" size="0.002 0.004" rgba="1 0 0 0" mass="0.001" condim="4"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||||
|
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||||
|
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||||
|
|
||||||
|
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<weld body1="right_hand" body2="B1" solimp="0.99 0.99 0.99" solref="0.02 1"></weld>
|
||||||
|
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
54
lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
Normal file
54
lerobot/common/envs/simxarm/simxarm/tasks/assets/push.xml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
|
||||||
|
<mujoco>
|
||||||
|
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||||
|
<size nconmax="2000" njmax="500"/>
|
||||||
|
|
||||||
|
<option timestep="0.002">
|
||||||
|
<flag warmstart="enable"></flag>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<include file="shared.xml"></include>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<body name="floor0" pos="0 0 0">
|
||||||
|
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||||
|
<site name="target0" pos="1.565 0.3 0.545" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<include file="xarm.xml"></include>
|
||||||
|
|
||||||
|
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||||
|
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||||
|
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body pos="1.5 0.075 0.3425" name="table0">
|
||||||
|
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="object" pos="1.655 0.3 0.68">
|
||||||
|
<joint name="object_joint0" type="free" limited="false"></joint>
|
||||||
|
<geom size="0.024 0.024 0.024" type="box" name="object" material="block_mat" density="50000" condim="4" friction="1 1 1" solimp="1 1 1" solref="0.02 1"></geom>
|
||||||
|
<site name="object_site" pos="0 0 0" size="0.024 0.024 0.024" rgba="0 0 0 0" type="box"></site>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||||
|
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||||
|
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||||
|
|
||||||
|
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
48
lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
Normal file
48
lerobot/common/envs/simxarm/simxarm/tasks/assets/reach.xml
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
|
||||||
|
<mujoco>
|
||||||
|
<compiler angle="radian" coordinate="local" meshdir="mesh" texturedir="texture"></compiler>
|
||||||
|
<size nconmax="2000" njmax="500"/>
|
||||||
|
|
||||||
|
<option timestep="0.002">
|
||||||
|
<flag warmstart="enable"></flag>
|
||||||
|
</option>
|
||||||
|
|
||||||
|
<include file="shared.xml"></include>
|
||||||
|
|
||||||
|
<worldbody>
|
||||||
|
<body name="floor0" pos="0 0 0">
|
||||||
|
<geom name="floorgeom0" pos="1.2 -2.0 0" size="1.0 10.0 1" type="plane" condim="3" material="floor_mat"></geom>
|
||||||
|
<site name="target0" pos="1.605 0.3 0.58" size="0.0475 0.001" rgba="1 0 0 1" type="cylinder"></site>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<include file="xarm.xml"></include>
|
||||||
|
|
||||||
|
<body pos="0.75 0 0.6325" name="pedestal0">
|
||||||
|
<geom name="pedestalgeom0" size="0.1 0.1 0.01" pos="0.32 0.27 0" type="box" mass="2000" material="pedestal_mat"></geom>
|
||||||
|
<site pos="0.30 0.30 0" size="0.075 0.075 0.002" type="box" name="robotmountsite0" rgba="0.55 0.54 0.53 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body pos="1.5 0.075 0.3425" name="table0">
|
||||||
|
<geom name="tablegeom0" size="0.3 0.6 0.2" pos="0 0 0" type="box" material="table_mat" density="2000" friction="1 0.005 0.0002"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="1.65 0 10" dir="-0.57 -0.57 -0.57" name="light0"></light>
|
||||||
|
<light directional="true" ambient="0.1 0.1 0.1" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="0 -4 4" dir="0 1 -0.1" name="light1"></light>
|
||||||
|
<light directional="true" ambient="0.05 0.05 0.05" diffuse="0 0 0" specular="0 0 0" castshadow="false" pos="2.13 1.6 2.5" name="light2"></light>
|
||||||
|
<light pos="0 0 2" dir="0.2 0.2 -0.8" directional="true" diffuse="0.3 0.3 0.3" castshadow="false" name="light3"></light>
|
||||||
|
|
||||||
|
<camera fovy="50" name="camera0" pos="0.9559 1.0 1.1" euler="-1.1 -0.6 3.4" />
|
||||||
|
</worldbody>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<connect body2="left_finger" body1="left_inner_knuckle" anchor="0.0 0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<connect body2="right_finger" body1="right_inner_knuckle" anchor="0.0 -0.035 0.042" solimp="0.9 0.95 0.001 0.5 2" solref="0.0002 1.0" ></connect>
|
||||||
|
<joint joint1="left_inner_knuckle_joint" joint2="right_inner_knuckle_joint"></joint>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<actuator>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="left_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="right_inner_knuckle_joint" gear="200.0"/>
|
||||||
|
</actuator>
|
||||||
|
</mujoco>
|
||||||
51
lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
Normal file
51
lerobot/common/envs/simxarm/simxarm/tasks/assets/shared.xml
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
<mujoco>
|
||||||
|
<asset>
|
||||||
|
<texture type="skybox" builtin="gradient" rgb1="0.0 0.0 0.0" rgb2="0.0 0.0 0.0" width="32" height="32"></texture>
|
||||||
|
<material name="floor_mat" specular="0" shininess="0.0" reflectance="0" rgba="0.043 0.055 0.051 1"></material>
|
||||||
|
|
||||||
|
<material name="table_mat" specular="0.2" shininess="0.2" reflectance="0" rgba="1 1 1 1"></material>
|
||||||
|
<material name="pedestal_mat" specular="0.35" shininess="0.5" reflectance="0" rgba="0.705 0.585 0.405 1"></material>
|
||||||
|
<material name="block_mat" specular="0.5" shininess="0.9" reflectance="0.05" rgba="0.373 0.678 0.627 1"></material>
|
||||||
|
|
||||||
|
<material name="robot0:geomMat" shininess="0.03" specular="0.4"></material>
|
||||||
|
<material name="robot0:gripper_finger_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="robot0:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="background:gripper_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="robot0:arm_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="robot0:head_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="robot0:torso_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
<material name="robot0:base_mat" shininess="0.03" specular="0.4" reflectance="0"></material>
|
||||||
|
|
||||||
|
<mesh name="link_base" file="link_base.stl" />
|
||||||
|
<mesh name="link1" file="link1.stl" />
|
||||||
|
<mesh name="link2" file="link2.stl" />
|
||||||
|
<mesh name="link3" file="link3.stl" />
|
||||||
|
<mesh name="link4" file="link4.stl" />
|
||||||
|
<mesh name="link5" file="link5.stl" />
|
||||||
|
<mesh name="link6" file="link6.stl" />
|
||||||
|
<mesh name="link7" file="link7.stl" />
|
||||||
|
<mesh name="base_link" file="base_link.stl" />
|
||||||
|
<mesh name="left_outer_knuckle" file="left_outer_knuckle.stl" />
|
||||||
|
<mesh name="left_finger" file="left_finger.stl" />
|
||||||
|
<mesh name="left_inner_knuckle" file="left_inner_knuckle.stl" />
|
||||||
|
<mesh name="right_outer_knuckle" file="right_outer_knuckle.stl" />
|
||||||
|
<mesh name="right_finger" file="right_finger.stl" />
|
||||||
|
<mesh name="right_inner_knuckle" file="right_inner_knuckle.stl" />
|
||||||
|
</asset>
|
||||||
|
|
||||||
|
<equality>
|
||||||
|
<weld body1="robot0:mocap2" body2="link7" solimp="0.9 0.95 0.001" solref="0.02 1"></weld>
|
||||||
|
</equality>
|
||||||
|
|
||||||
|
<default>
|
||||||
|
<joint armature="1" damping="0.1" limited="true"/>
|
||||||
|
<default class="robot0:blue">
|
||||||
|
<geom rgba="0.086 0.506 0.767 1.0"></geom>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
<default class="robot0:grey">
|
||||||
|
<geom rgba="0.356 0.361 0.376 1.0"></geom>
|
||||||
|
</default>
|
||||||
|
</default>
|
||||||
|
|
||||||
|
</mujoco>
|
||||||
88
lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
Normal file
88
lerobot/common/envs/simxarm/simxarm/tasks/assets/xarm.xml
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
<mujoco model="xarm7">
|
||||||
|
<body mocap="true" name="robot0:mocap2" pos="0 0 0">
|
||||||
|
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0.5 0 0" size="0.005 0.005 0.005" type="box"></geom>
|
||||||
|
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0 0 0" size="1 0.005 0.005" type="box"></geom>
|
||||||
|
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0 0 0.5 0" size="0.005 1 0.001" type="box"></geom>
|
||||||
|
<geom conaffinity="0" contype="0" pos="0 0 0" rgba="0.5 0.5 0 0" size="0.005 0.005 1" type="box"></geom>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="link0" pos="1.09 0.28 0.655">
|
||||||
|
<geom name="bb" type="mesh" mesh="link_base" material="robot0:base_mat" rgba="1 1 1 1"/>
|
||||||
|
<body name="link1" pos="0 0 0.267">
|
||||||
|
<inertial pos="-0.0042142 0.02821 -0.0087788" quat="0.917781 -0.277115 0.0606681 0.277858" mass="0.42603" diaginertia="0.00144551 0.00137757 0.000823511" />
|
||||||
|
<joint name="joint1" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="10" frictionloss="1" />
|
||||||
|
<geom name="j1" type="mesh" mesh="link1" material="robot0:arm_mat" rgba="1 1 1 1"/>
|
||||||
|
<body name="link2" pos="0 0 0" quat="0.707105 -0.707108 0 0">
|
||||||
|
<inertial pos="-3.3178e-05 -0.12849 0.026337" quat="0.447793 0.894132 -0.00224061 0.00218314" mass="0.56095" diaginertia="0.00319151 0.00311598 0.000980804" />
|
||||||
|
<joint name="joint2" pos="0 0 0" axis="0 0 1" limited="true" range="-2.059 2.0944" damping="10" frictionloss="1" />
|
||||||
|
<geom name="j2" type="mesh" mesh="link2" material="robot0:head_mat" rgba="1 1 1 1"/>
|
||||||
|
<body name="link3" pos="0 -0.293 0" quat="0.707105 0.707108 0 0">
|
||||||
|
<inertial pos="0.04223 -0.023258 -0.0096674" quat="0.883205 0.339803 0.323238 0.000542237" mass="0.44463" diaginertia="0.00133227 0.00119126 0.000780475" />
|
||||||
|
<joint name="joint3" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
|
||||||
|
<geom name="j3" type="mesh" mesh="link3" material="robot0:gripper_mat" rgba="1 1 1 1"/>
|
||||||
|
<body name="link4" pos="0.0525 0 0" quat="0.707105 0.707108 0 0">
|
||||||
|
<inertial pos="0.067148 -0.10732 0.024479" quat="0.0654142 0.483317 -0.738663 0.465298" mass="0.52387" diaginertia="0.00288984 0.00282705 0.000894409" />
|
||||||
|
<joint name="joint4" pos="0 0 0" axis="0 0 1" limited="true" range="-0.19198 3.927" damping="5" frictionloss="1" />
|
||||||
|
<geom name="j4" type="mesh" mesh="link4" material="robot0:arm_mat" rgba="1 1 1 1"/>
|
||||||
|
<body name="link5" pos="0.0775 -0.3425 0" quat="0.707105 0.707108 0 0">
|
||||||
|
<inertial pos="-0.00023397 0.036705 -0.080064" quat="0.981064 -0.19003 0.00637998 0.0369004" mass="0.18554" diaginertia="0.00099553 0.000988613 0.000247126" />
|
||||||
|
<joint name="joint5" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="5" frictionloss="1" />
|
||||||
|
<geom name="j5" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link5" />
|
||||||
|
<body name="link6" pos="0 0 0" quat="0.707105 0.707108 0 0">
|
||||||
|
<inertial pos="0.058911 0.028469 0.0068428" quat="-0.188705 0.793535 0.166088 0.554173" mass="0.31344" diaginertia="0.000827892 0.000768871 0.000386708" />
|
||||||
|
<joint name="joint6" pos="0 0 0" axis="0 0 1" limited="true" range="-1.69297 3.14159" damping="2" frictionloss="1" />
|
||||||
|
<geom name="j6" type="mesh" material="robot0:gripper_mat" rgba="1 1 1 1" mesh="link6" />
|
||||||
|
<body name="link7" pos="0.076 0.097 0" quat="0.707105 -0.707108 0 0">
|
||||||
|
<inertial pos="-0.000420033 -0.00287433 0.0257078" quat="0.999372 -0.0349129 -0.00605634 0.000551744" mass="0.85624" diaginertia="0.00137671 0.00118744 0.000514968" />
|
||||||
|
<joint name="joint7" pos="0 0 0" axis="0 0 1" limited="true" range="-6.28319 6.28319" damping="2" frictionloss="1" />
|
||||||
|
<geom name="j8" material="robot0:gripper_mat" type="mesh" rgba="0.753 0.753 0.753 1" mesh="link7" />
|
||||||
|
<geom name="j9" material="robot0:gripper_mat" type="mesh" rgba="1 1 1 1" mesh="base_link" />
|
||||||
|
<site name="grasp" pos="0 0 0.16" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
|
||||||
|
<body name="left_outer_knuckle" pos="0 0.035 0.059098">
|
||||||
|
<inertial pos="0 0.021559 0.015181" quat="0.47789 0.87842 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
|
||||||
|
<joint name="drive_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_outer_knuckle" />
|
||||||
|
<body name="left_finger" pos="0 0.035465 0.042039">
|
||||||
|
<inertial pos="0 -0.016413 0.029258" quat="0.697634 0.115353 -0.115353 0.697634" mass="0.048304" diaginertia="1.88037e-05 1.7493e-05 3.56792e-06" />
|
||||||
|
<joint name="left_finger_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom name="j10" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="left_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||||
|
<body name="right_hand" pos="0 -0.03 0.05" quat="-0.7071 0 0 0.7071">
|
||||||
|
<site name="ee" pos="0 0 0" rgba="0 0 1 0" type="sphere" group="1"/>
|
||||||
|
<site name="ee_x" pos="0 0 0" size="0.005 .1" quat="0.707105 0.707108 0 0 " rgba="1 0 0 0" type="cylinder" group="1"/>
|
||||||
|
<site name="ee_z" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0 0.707108" rgba="0 0 1 0" type="cylinder" group="1"/>
|
||||||
|
<site name="ee_y" pos="0 0 0" size="0.005 .1" quat="0.707105 0 0.707108 0 " rgba="0 1 0 0" type="cylinder" group="1"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
<body name="left_inner_knuckle" pos="0 0.02 0.074098">
|
||||||
|
<inertial pos="1.86601e-06 0.0220468 0.0261335" quat="0.664139 -0.242732 0.242713 0.664146" mass="0.0230126" diaginertia="8.34216e-06 6.0949e-06 2.75601e-06" />
|
||||||
|
<joint name="left_inner_knuckle_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="left_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||||
|
</body>
|
||||||
|
<body name="right_outer_knuckle" pos="0 -0.035 0.059098">
|
||||||
|
<inertial pos="0 -0.021559 0.015181" quat="0.87842 0.47789 0 0" mass="0.033618" diaginertia="1.9111e-05 1.79089e-05 1.90167e-06" />
|
||||||
|
<joint name="right_outer_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_outer_knuckle" />
|
||||||
|
<body name="right_finger" pos="0 -0.035465 0.042039">
|
||||||
|
<inertial pos="0 0.016413 0.029258" quat="0.697634 -0.115356 0.115356 0.697634" mass="0.048304" diaginertia="1.88038e-05 1.7493e-05 3.56779e-06" />
|
||||||
|
<joint name="right_finger_joint" pos="0 0 0" axis="1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom name="j11" material="robot0:gripper_finger_mat" type="mesh" rgba="0 0 0 1" conaffinity="3" contype="2" mesh="right_finger" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||||
|
<body name="left_hand" pos="0 0.03 0.05" quat="-0.7071 0 0 0.7071">
|
||||||
|
<site name="ee_2" pos="0 0 0" rgba="1 0 0 0" type="sphere" size="0.01" group="1"/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
<body name="right_inner_knuckle" pos="0 -0.02 0.074098">
|
||||||
|
<inertial pos="1.866e-06 -0.022047 0.026133" quat="0.66415 0.242702 -0.242721 0.664144" mass="0.023013" diaginertia="8.34209e-06 6.0949e-06 2.75601e-06" />
|
||||||
|
<joint name="right_inner_knuckle_joint" pos="0 0 0" axis="-1 0 0" limited="true" range="0 0.85" />
|
||||||
|
<geom type="mesh" rgba="0 0 0 1" conaffinity="1" contype="0" mesh="right_inner_knuckle" friction='1.5 1.5 1.5' solref='0.01 1' solimp='0.99 0.99 0.01'/>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</body>
|
||||||
|
</mujoco>
|
||||||
145
lerobot/common/envs/simxarm/simxarm/tasks/base.py
Normal file
145
lerobot/common/envs/simxarm/simxarm/tasks/base.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import mujoco
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium_robotics.envs import robot_env
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||||
|
|
||||||
|
|
||||||
|
class Base(robot_env.MujocoRobotEnv):
|
||||||
|
"""
|
||||||
|
Superclass for all simxarm environments.
|
||||||
|
Args:
|
||||||
|
xml_name (str): name of the xml environment file
|
||||||
|
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, xml_name, gripper_rotation=None):
|
||||||
|
if gripper_rotation is None:
|
||||||
|
gripper_rotation = [0, 1, 0, 0]
|
||||||
|
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
|
||||||
|
self.center_of_table = np.array([1.655, 0.3, 0.63625])
|
||||||
|
self.max_z = 1.2
|
||||||
|
self.min_z = 0.2
|
||||||
|
super().__init__(
|
||||||
|
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
|
||||||
|
n_substeps=20,
|
||||||
|
n_actions=4,
|
||||||
|
initial_qpos={},
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self):
|
||||||
|
return self.n_substeps * self.model.opt.timestep
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eef(self):
|
||||||
|
return self._utils.get_site_xpos(self.model, self.data, "grasp")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def obj(self):
|
||||||
|
return self._utils.get_site_xpos(self.model, self.data, "object_site")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def robot_state(self):
|
||||||
|
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
|
||||||
|
return np.concatenate([self.eef, gripper_angle])
|
||||||
|
|
||||||
|
def is_success(self):
|
||||||
|
return NotImplementedError()
|
||||||
|
|
||||||
|
def get_reward(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _sample_goal(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_obs(self):
|
||||||
|
return self._get_obs()
|
||||||
|
|
||||||
|
def _step_callback(self):
|
||||||
|
self._mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
|
def _limit_gripper(self, gripper_pos, pos_ctrl):
|
||||||
|
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
|
||||||
|
pos_ctrl[0] = min(pos_ctrl[0], 0)
|
||||||
|
if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3:
|
||||||
|
pos_ctrl[0] = max(pos_ctrl[0], 0)
|
||||||
|
if gripper_pos[1] > self.center_of_table[1] + 0.3:
|
||||||
|
pos_ctrl[1] = min(pos_ctrl[1], 0)
|
||||||
|
if gripper_pos[1] < self.center_of_table[1] - 0.3:
|
||||||
|
pos_ctrl[1] = max(pos_ctrl[1], 0)
|
||||||
|
if gripper_pos[2] > self.max_z:
|
||||||
|
pos_ctrl[2] = min(pos_ctrl[2], 0)
|
||||||
|
if gripper_pos[2] < self.min_z:
|
||||||
|
pos_ctrl[2] = max(pos_ctrl[2], 0)
|
||||||
|
return pos_ctrl
|
||||||
|
|
||||||
|
def _apply_action(self, action):
|
||||||
|
assert action.shape == (4,)
|
||||||
|
action = action.copy()
|
||||||
|
pos_ctrl, gripper_ctrl = action[:3], action[3]
|
||||||
|
pos_ctrl = self._limit_gripper(
|
||||||
|
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
|
||||||
|
) * (1 / self.n_substeps)
|
||||||
|
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
|
||||||
|
mocap.apply_action(
|
||||||
|
self.model,
|
||||||
|
self._model_names,
|
||||||
|
self.data,
|
||||||
|
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _render_callback(self):
|
||||||
|
self._mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
|
def _reset_sim(self):
|
||||||
|
self.data.time = self.initial_time
|
||||||
|
self.data.qpos[:] = np.copy(self.initial_qpos)
|
||||||
|
self.data.qvel[:] = np.copy(self.initial_qvel)
|
||||||
|
self._sample_goal()
|
||||||
|
self._mujoco.mj_step(self.model, self.data, nstep=10)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _set_gripper(self, gripper_pos, gripper_rotation):
|
||||||
|
self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos)
|
||||||
|
self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation)
|
||||||
|
self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0)
|
||||||
|
self.data.qpos[10] = 0.0
|
||||||
|
self.data.qpos[12] = 0.0
|
||||||
|
|
||||||
|
def _env_setup(self, initial_qpos):
|
||||||
|
for name, value in initial_qpos.items():
|
||||||
|
self.data.set_joint_qpos(name, value)
|
||||||
|
mocap.reset(self.model, self.data)
|
||||||
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
self._sample_goal()
|
||||||
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._reset_sim()
|
||||||
|
return self._get_obs()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
assert action.shape == (4,)
|
||||||
|
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
|
||||||
|
self._apply_action(action)
|
||||||
|
self._mujoco.mj_step(self.model, self.data, nstep=2)
|
||||||
|
self._step_callback()
|
||||||
|
obs = self._get_obs()
|
||||||
|
reward = self.get_reward()
|
||||||
|
done = False
|
||||||
|
info = {"is_success": self.is_success(), "success": self.is_success()}
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
def render(self, mode="rgb_array", width=384, height=384):
|
||||||
|
self._render_callback()
|
||||||
|
# HACK
|
||||||
|
self.model.vis.global_.offwidth = width
|
||||||
|
self.model.vis.global_.offheight = height
|
||||||
|
return self.mujoco_renderer.render(mode)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.mujoco_renderer is not None:
|
||||||
|
self.mujoco_renderer.close()
|
||||||
100
lerobot/common/envs/simxarm/simxarm/tasks/lift.py
Normal file
100
lerobot/common/envs/simxarm/simxarm/tasks/lift.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Lift(Base):
|
||||||
|
def __init__(self):
|
||||||
|
self._z_threshold = 0.15
|
||||||
|
super().__init__("lift")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def z_target(self):
|
||||||
|
return self._init_z + self._z_threshold
|
||||||
|
|
||||||
|
def is_success(self):
|
||||||
|
return self.obj[2] >= self.z_target
|
||||||
|
|
||||||
|
def get_reward(self):
|
||||||
|
reach_dist = np.linalg.norm(self.obj - self.eef)
|
||||||
|
reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1])
|
||||||
|
pick_completed = self.obj[2] >= (self.z_target - 0.01)
|
||||||
|
obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02)
|
||||||
|
|
||||||
|
# Reach
|
||||||
|
if reach_dist < 0.05:
|
||||||
|
reach_reward = -reach_dist + max(self._action[-1], 0) / 50
|
||||||
|
elif reach_dist_xy < 0.05:
|
||||||
|
reach_reward = -reach_dist
|
||||||
|
else:
|
||||||
|
z_bonus = np.linalg.norm(np.linalg.norm(self.obj[-1] - self.eef[-1]))
|
||||||
|
reach_reward = -reach_dist - 2 * z_bonus
|
||||||
|
|
||||||
|
# Pick
|
||||||
|
if pick_completed and not obj_dropped:
|
||||||
|
pick_reward = self.z_target
|
||||||
|
elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)):
|
||||||
|
pick_reward = min(self.z_target, self.obj[2])
|
||||||
|
else:
|
||||||
|
pick_reward = 0
|
||||||
|
|
||||||
|
return reach_reward / 100 + pick_reward
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt
|
||||||
|
gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint")
|
||||||
|
eef = self.eef - self.center_of_table
|
||||||
|
|
||||||
|
obj = self.obj - self.center_of_table
|
||||||
|
obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:]
|
||||||
|
obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt
|
||||||
|
obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt
|
||||||
|
|
||||||
|
obs = np.concatenate(
|
||||||
|
[
|
||||||
|
eef,
|
||||||
|
eef_velp,
|
||||||
|
obj,
|
||||||
|
obj_rot,
|
||||||
|
obj_velp,
|
||||||
|
obj_velr,
|
||||||
|
eef - obj,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
np.linalg.norm(eef - obj),
|
||||||
|
np.linalg.norm(eef[:-1] - obj[:-1]),
|
||||||
|
self.z_target,
|
||||||
|
self.z_target - obj[-1],
|
||||||
|
self.z_target - eef[-1],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
gripper_angle,
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef}
|
||||||
|
|
||||||
|
def _sample_goal(self):
|
||||||
|
# Gripper
|
||||||
|
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||||
|
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||||
|
|
||||||
|
# Object
|
||||||
|
object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07])
|
||||||
|
object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1)
|
||||||
|
object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1)
|
||||||
|
object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")
|
||||||
|
object_qpos[:3] = object_pos
|
||||||
|
self._utils.set_joint_qpos(self.model, self.data, "object_joint0", object_qpos)
|
||||||
|
self._init_z = object_pos[2]
|
||||||
|
|
||||||
|
# Goal
|
||||||
|
return object_pos + np.array([0, 0, self._z_threshold])
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._action = np.zeros(4)
|
||||||
|
return super().reset()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self._action = action.copy()
|
||||||
|
return super().step(action)
|
||||||
67
lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
Normal file
67
lerobot/common/envs/simxarm/simxarm/tasks/mocap.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# import mujoco_py
|
||||||
|
import mujoco
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def apply_action(model, model_names, data, action):
|
||||||
|
if model.nmocap > 0:
|
||||||
|
pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
|
||||||
|
if data.ctrl is not None:
|
||||||
|
for i in range(gripper_action.shape[0]):
|
||||||
|
data.ctrl[i] = gripper_action[i]
|
||||||
|
pos_action = pos_action.reshape(model.nmocap, 7)
|
||||||
|
pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
|
||||||
|
reset_mocap2body_xpos(model, model_names, data)
|
||||||
|
data.mocap_pos[:] = data.mocap_pos + pos_delta
|
||||||
|
data.mocap_quat[:] = data.mocap_quat + quat_delta
|
||||||
|
|
||||||
|
|
||||||
|
def reset(model, data):
|
||||||
|
if model.nmocap > 0 and model.eq_data is not None:
|
||||||
|
for i in range(model.eq_data.shape[0]):
|
||||||
|
# if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD:
|
||||||
|
if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD:
|
||||||
|
# model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.])
|
||||||
|
model.eq_data[i, :] = np.array(
|
||||||
|
[
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# sim.forward()
|
||||||
|
mujoco.mj_forward(model, data)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_mocap2body_xpos(model, model_names, data):
|
||||||
|
if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# For all weld constraints
|
||||||
|
for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
|
||||||
|
# if eq_type != mujoco_py.const.EQ_WELD:
|
||||||
|
if eq_type != mujoco.mjtEq.mjEQ_WELD:
|
||||||
|
continue
|
||||||
|
# body2 = model.body_id2name(obj2_id)
|
||||||
|
body2 = model_names.body_id2name[obj2_id]
|
||||||
|
if body2 == "B0" or body2 == "B9" or body2 == "B1":
|
||||||
|
continue
|
||||||
|
mocap_id = model.body_mocapid[obj1_id]
|
||||||
|
if mocap_id != -1:
|
||||||
|
# obj1 is the mocap, obj2 is the welded body
|
||||||
|
body_idx = obj2_id
|
||||||
|
else:
|
||||||
|
# obj2 is the mocap, obj1 is the welded body
|
||||||
|
mocap_id = model.body_mocapid[obj2_id]
|
||||||
|
body_idx = obj1_id
|
||||||
|
assert mocap_id != -1
|
||||||
|
data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
|
||||||
|
data.mocap_quat[mocap_id][:] = data.xquat[body_idx]
|
||||||
86
lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
Normal file
86
lerobot/common/envs/simxarm/simxarm/tasks/peg_in_box.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm import Base
|
||||||
|
|
||||||
|
|
||||||
|
class PegInBox(Base):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("peg_in_box")
|
||||||
|
|
||||||
|
def _reset_sim(self):
|
||||||
|
self._act_magnitude = 0
|
||||||
|
super()._reset_sim()
|
||||||
|
for _ in range(10):
|
||||||
|
self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32))
|
||||||
|
self.sim.step()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def box(self):
|
||||||
|
return self.sim.data.get_site_xpos("box_site")
|
||||||
|
|
||||||
|
def is_success(self):
|
||||||
|
return np.linalg.norm(self.obj - self.box) <= 0.05
|
||||||
|
|
||||||
|
def get_reward(self):
|
||||||
|
dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2])
|
||||||
|
dist_xyz = np.linalg.norm(self.obj - self.box)
|
||||||
|
return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||||
|
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||||
|
eef, box = self.eef - self.center_of_table, self.box - self.center_of_table
|
||||||
|
|
||||||
|
obj = self.obj - self.center_of_table
|
||||||
|
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
|
||||||
|
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
|
||||||
|
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
|
||||||
|
|
||||||
|
obs = np.concatenate(
|
||||||
|
[
|
||||||
|
eef,
|
||||||
|
eef_velp,
|
||||||
|
box,
|
||||||
|
obj,
|
||||||
|
obj_rot,
|
||||||
|
obj_velp,
|
||||||
|
obj_velr,
|
||||||
|
eef - box,
|
||||||
|
eef - obj,
|
||||||
|
obj - box,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
np.linalg.norm(eef - box),
|
||||||
|
np.linalg.norm(eef - obj),
|
||||||
|
np.linalg.norm(obj - box),
|
||||||
|
gripper_angle,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box}
|
||||||
|
|
||||||
|
def _sample_goal(self):
|
||||||
|
# Gripper
|
||||||
|
gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||||
|
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||||
|
|
||||||
|
# Object
|
||||||
|
object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3)
|
||||||
|
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
|
||||||
|
object_qpos[:3] = object_pos
|
||||||
|
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
|
||||||
|
|
||||||
|
# Box
|
||||||
|
box_pos = np.array([1.61, 0.18, 0.58])
|
||||||
|
box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2)
|
||||||
|
box_qpos = self.sim.data.get_joint_qpos("box_joint0")
|
||||||
|
box_qpos[:3] = box_pos
|
||||||
|
self.sim.data.set_joint_qpos("box_joint0", box_qpos)
|
||||||
|
|
||||||
|
return self.box
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self._act_magnitude = np.linalg.norm(action[:3])
|
||||||
|
return super().step(action)
|
||||||
78
lerobot/common/envs/simxarm/simxarm/tasks/push.py
Normal file
78
lerobot/common/envs/simxarm/simxarm/tasks/push.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Push(Base):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("push")
|
||||||
|
|
||||||
|
def _reset_sim(self):
|
||||||
|
self._act_magnitude = 0
|
||||||
|
super()._reset_sim()
|
||||||
|
|
||||||
|
def is_success(self):
|
||||||
|
return np.linalg.norm(self.obj - self.goal) <= 0.05
|
||||||
|
|
||||||
|
def get_reward(self):
|
||||||
|
dist = np.linalg.norm(self.obj - self.goal)
|
||||||
|
penalty = self._act_magnitude**2
|
||||||
|
return -(dist + 0.15 * penalty)
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||||
|
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||||
|
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
|
||||||
|
|
||||||
|
obj = self.obj - self.center_of_table
|
||||||
|
obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:]
|
||||||
|
obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt
|
||||||
|
obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt
|
||||||
|
|
||||||
|
obs = np.concatenate(
|
||||||
|
[
|
||||||
|
eef,
|
||||||
|
eef_velp,
|
||||||
|
goal,
|
||||||
|
obj,
|
||||||
|
obj_rot,
|
||||||
|
obj_velp,
|
||||||
|
obj_velr,
|
||||||
|
eef - goal,
|
||||||
|
eef - obj,
|
||||||
|
obj - goal,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
np.linalg.norm(eef - goal),
|
||||||
|
np.linalg.norm(eef - obj),
|
||||||
|
np.linalg.norm(obj - goal),
|
||||||
|
gripper_angle,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
|
||||||
|
|
||||||
|
def _sample_goal(self):
|
||||||
|
# Gripper
|
||||||
|
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||||
|
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||||
|
|
||||||
|
# Object
|
||||||
|
object_pos = self.center_of_table - np.array([0.25, 0, 0.07])
|
||||||
|
object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1)
|
||||||
|
object_pos[1] += self.np_random.uniform(-0.08, 0.08, size=1)
|
||||||
|
object_qpos = self.sim.data.get_joint_qpos("object_joint0")
|
||||||
|
object_qpos[:3] = object_pos
|
||||||
|
self.sim.data.set_joint_qpos("object_joint0", object_qpos)
|
||||||
|
|
||||||
|
# Goal
|
||||||
|
self.goal = np.array([1.600, 0.200, 0.545])
|
||||||
|
self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2)
|
||||||
|
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
|
||||||
|
return self.goal
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self._act_magnitude = np.linalg.norm(action[:3])
|
||||||
|
return super().step(action)
|
||||||
44
lerobot/common/envs/simxarm/simxarm/tasks/reach.py
Normal file
44
lerobot/common/envs/simxarm/simxarm/tasks/reach.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.common.envs.simxarm.simxarm import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Reach(Base):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("reach")
|
||||||
|
|
||||||
|
def _reset_sim(self):
|
||||||
|
self._act_magnitude = 0
|
||||||
|
super()._reset_sim()
|
||||||
|
|
||||||
|
def is_success(self):
|
||||||
|
return np.linalg.norm(self.eef - self.goal) <= 0.05
|
||||||
|
|
||||||
|
def get_reward(self):
|
||||||
|
dist = np.linalg.norm(self.eef - self.goal)
|
||||||
|
penalty = self._act_magnitude**2
|
||||||
|
return -(dist + 0.15 * penalty)
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt
|
||||||
|
gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint")
|
||||||
|
eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table
|
||||||
|
obs = np.concatenate(
|
||||||
|
[eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0
|
||||||
|
)
|
||||||
|
return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal}
|
||||||
|
|
||||||
|
def _sample_goal(self):
|
||||||
|
# Gripper
|
||||||
|
gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3)
|
||||||
|
super()._set_gripper(gripper_pos, self.gripper_rotation)
|
||||||
|
|
||||||
|
# Goal
|
||||||
|
self.goal = np.array([1.550, 0.287, 0.580])
|
||||||
|
self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2)
|
||||||
|
self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal
|
||||||
|
return self.goal
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
self._act_magnitude = np.linalg.norm(action[:3])
|
||||||
|
return super().step(action)
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import einops
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_observation(observation):
|
|
||||||
# map to expected inputs for the policy
|
|
||||||
obs = {}
|
|
||||||
|
|
||||||
if isinstance(observation["pixels"], dict):
|
|
||||||
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
|
|
||||||
else:
|
|
||||||
imgs = {"observation.image": observation["pixels"]}
|
|
||||||
|
|
||||||
for imgkey, img in imgs.items():
|
|
||||||
img = torch.from_numpy(img)
|
|
||||||
|
|
||||||
# sanity check that images are channel last
|
|
||||||
_, h, w, c = img.shape
|
|
||||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
|
||||||
|
|
||||||
# sanity check that images are uint8
|
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
|
||||||
|
|
||||||
# convert to channel first of type float32 in range [0,1]
|
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
|
||||||
img = img.type(torch.float32)
|
|
||||||
img /= 255
|
|
||||||
|
|
||||||
obs[imgkey] = img
|
|
||||||
|
|
||||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
|
||||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
|
||||||
|
|
||||||
return obs
|
|
||||||
|
|
||||||
|
|
||||||
def postprocess_action(action):
|
|
||||||
action = action.to("cpu").numpy()
|
|
||||||
assert (
|
|
||||||
action.ndim == 2
|
|
||||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
|
||||||
return action
|
|
||||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
|
|
||||||
def log_output_dir(out_dir):
|
def log_output_dir(out_dir):
|
||||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||||
@@ -67,11 +68,11 @@ class Logger:
|
|||||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
|
|
||||||
def save_model(self, policy, identifier):
|
def save_model(self, policy: AbstractPolicy, identifier):
|
||||||
if self._save_model:
|
if self._save_model:
|
||||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||||
policy.save(fp)
|
policy.save_pretrained(fp)
|
||||||
if self._wandb and not self._disable_wandb_artifact:
|
if self._wandb and not self._disable_wandb_artifact:
|
||||||
# note wandb artifact does not accept ":" in its name
|
# note wandb artifact does not accept ":" in its name
|
||||||
artifact = self._wandb.Artifact(
|
artifact = self._wandb.Artifact(
|
||||||
|
|||||||
93
lerobot/common/policies/abstract.py
Normal file
93
lerobot/common/policies/abstract.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
"""Base policy which all policies should be derived from.
|
||||||
|
|
||||||
|
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||||
|
documentation for more information.
|
||||||
|
|
||||||
|
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
|
||||||
|
# Save policy weights to local directory
|
||||||
|
>>> policy.save_pretrained("my-awesome-policy")
|
||||||
|
|
||||||
|
# Push policy weights to the Hub
|
||||||
|
>>> policy.push_to_hub("my-awesome-policy")
|
||||||
|
|
||||||
|
# Download and initialize policy from the Hub
|
||||||
|
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
|
||||||
|
|
||||||
|
Note:
|
||||||
|
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||||
|
1. set the required class attributes:
|
||||||
|
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||||
|
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||||
|
- for classes inheriting from `AbstractPolicy`: `name`
|
||||||
|
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||||
|
3. update variables in `tests/test_available.py` by importing your new class
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||||
|
|
||||||
|
def __init__(self, n_action_steps: int | None = None):
|
||||||
|
"""
|
||||||
|
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||||
|
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||||
|
adds that dimension.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.clear_action_queue()
|
||||||
|
|
||||||
|
def update(self, replay_buffer, step):
|
||||||
|
"""One step of the policy's learning algorithm."""
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||||
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
||||||
|
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||||
|
d = torch.load(fp)
|
||||||
|
self.load_state_dict(d)
|
||||||
|
|
||||||
|
def select_actions(self, observation) -> Tensor:
|
||||||
|
"""Select an action (or trajectory of actions) based on an observation during rollout.
|
||||||
|
|
||||||
|
If n_action_steps was provided at initialization, this should return a (batch_size, n_action_steps, *) tensor of
|
||||||
|
actions. Otherwise if n_actions_steps is None, this should return a (batch_size, *) tensor of actions.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
|
def clear_action_queue(self):
|
||||||
|
"""This should be called whenever the environment is reset."""
|
||||||
|
if self.n_action_steps is not None:
|
||||||
|
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs) -> Tensor:
|
||||||
|
"""Inference step that makes multi-step policies compatible with their single-step environments.
|
||||||
|
|
||||||
|
WARNING: In general, this should not be overriden.
|
||||||
|
|
||||||
|
Consider a "policy" that observes the environment then charts a course of N actions to take. To make this fit
|
||||||
|
into the formalism of a TorchRL environment, we view it as being effectively a policy that (1) makes an
|
||||||
|
observation and prepares a queue of actions, (2) consumes that queue when queried, regardless of the environment
|
||||||
|
observation, (3) repopulates the action queue when empty. This method handles the aforementioned logic so that
|
||||||
|
the subclass doesn't have to.
|
||||||
|
|
||||||
|
This method effectively wraps the `select_actions` method of the subclass. The following assumptions are made:
|
||||||
|
1. The `select_actions` method returns a Tensor of actions with shape (B, H, *) where B is the batch size, H is
|
||||||
|
the action trajectory horizon and * is the action dimensions.
|
||||||
|
2. Prior to the `select_actions` method being called, theres is an `n_action_steps` instance attribute defined.
|
||||||
|
"""
|
||||||
|
if self.n_action_steps is None:
|
||||||
|
return self.select_actions(*args, **kwargs)
|
||||||
|
if len(self._action_queue) == 0:
|
||||||
|
# `select_actions` returns a (batch_size, n_action_steps, *) tensor, but the queue effectively has shape
|
||||||
|
# (n_action_steps, batch_size, *), hence the transpose.
|
||||||
|
self._action_queue.extend(self.select_actions(*args, **kwargs).transpose(0, 1))
|
||||||
|
return self._action_queue.popleft()
|
||||||
115
lerobot/common/policies/act/backbone.py
Normal file
115
lerobot/common/policies/act/backbone.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
|
||||||
|
from .position_encoding import build_position_encoding
|
||||||
|
from .utils import NestedTensor, is_main_process
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenBatchNorm2d(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||||
|
|
||||||
|
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||||
|
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||||
|
produce nans.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("weight", torch.ones(n))
|
||||||
|
self.register_buffer("bias", torch.zeros(n))
|
||||||
|
self.register_buffer("running_mean", torch.zeros(n))
|
||||||
|
self.register_buffer("running_var", torch.ones(n))
|
||||||
|
|
||||||
|
def _load_from_state_dict(
|
||||||
|
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||||
|
):
|
||||||
|
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||||
|
if num_batches_tracked_key in state_dict:
|
||||||
|
del state_dict[num_batches_tracked_key]
|
||||||
|
|
||||||
|
super()._load_from_state_dict(
|
||||||
|
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# move reshapes to the beginning
|
||||||
|
# to make it fuser-friendly
|
||||||
|
w = self.weight.reshape(1, -1, 1, 1)
|
||||||
|
b = self.bias.reshape(1, -1, 1, 1)
|
||||||
|
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||||
|
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||||
|
eps = 1e-5
|
||||||
|
scale = w * (rv + eps).rsqrt()
|
||||||
|
bias = b - rm * scale
|
||||||
|
return x * scale + bias
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneBase(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||||
|
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||||
|
# parameter.requires_grad_(False)
|
||||||
|
if return_interm_layers:
|
||||||
|
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||||
|
else:
|
||||||
|
return_layers = {"layer4": "0"}
|
||||||
|
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
xs = self.body(tensor)
|
||||||
|
return xs
|
||||||
|
# out: Dict[str, NestedTensor] = {}
|
||||||
|
# for name, x in xs.items():
|
||||||
|
# m = tensor_list.mask
|
||||||
|
# assert m is not None
|
||||||
|
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||||
|
# out[name] = NestedTensor(x, mask)
|
||||||
|
# return out
|
||||||
|
|
||||||
|
|
||||||
|
class Backbone(BackboneBase):
|
||||||
|
"""ResNet backbone with frozen BatchNorm."""
|
||||||
|
|
||||||
|
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
||||||
|
backbone = getattr(torchvision.models, name)(
|
||||||
|
replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
pretrained=is_main_process(),
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
) # pretrained # TODO do we want frozen batch_norm??
|
||||||
|
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||||
|
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||||
|
|
||||||
|
|
||||||
|
class Joiner(nn.Sequential):
|
||||||
|
def __init__(self, backbone, position_embedding):
|
||||||
|
super().__init__(backbone, position_embedding)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
xs = self[0](tensor_list)
|
||||||
|
out: List[NestedTensor] = []
|
||||||
|
pos = []
|
||||||
|
for _, x in xs.items():
|
||||||
|
out.append(x)
|
||||||
|
# position encoding
|
||||||
|
pos.append(self[1](x).to(x.dtype))
|
||||||
|
|
||||||
|
return out, pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_backbone(args):
|
||||||
|
position_embedding = build_position_encoding(args)
|
||||||
|
train_backbone = args.lr_backbone > 0
|
||||||
|
return_interm_layers = args.masks
|
||||||
|
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||||
|
model = Joiner(backbone, position_embedding)
|
||||||
|
model.num_channels = backbone.num_channels
|
||||||
|
return model
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerConfig(PretrainedConfig):
|
|
||||||
"""Configuration class for the Action Chunking Transformers policy.
|
|
||||||
|
|
||||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
|
||||||
|
|
||||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
|
||||||
Those are: `input_shapes` and 'output_shapes`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
|
||||||
current step and additional steps going back).
|
|
||||||
chunk_size: The size of the action prediction "chunks" in units of environment steps.
|
|
||||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
|
||||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
|
||||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
|
||||||
environment, and throws the other 50 out.
|
|
||||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
|
||||||
The key represents the input data name, and the value is a list indicating the dimensions
|
|
||||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
|
||||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
|
||||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
|
||||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
|
||||||
The key represents the output data name, and the value is a list indicating the dimensions
|
|
||||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
|
||||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
|
||||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
|
||||||
and the value specifies the normalization mode to apply. The two availables
|
|
||||||
modes are "mean_std" which substracts the mean and divide by the standard
|
|
||||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
|
||||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
|
||||||
torchvision.
|
|
||||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
|
||||||
convolution.
|
|
||||||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
|
||||||
d_model: The transformer blocks' main hidden dimension.
|
|
||||||
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
|
|
||||||
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
|
|
||||||
layers.
|
|
||||||
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
|
|
||||||
n_encoder_layers: The number of transformer layers to use for the transformer encoder.
|
|
||||||
n_decoder_layers: The number of transformer layers to use for the transformer decoder.
|
|
||||||
use_vae: Whether to use a variational objective during training. This introduces another transformer
|
|
||||||
which is used as the VAE's encoder (not to be confused with the transformer encoder - see
|
|
||||||
documentation in the policy class).
|
|
||||||
latent_dim: The VAE's latent dimension.
|
|
||||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
|
||||||
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
|
|
||||||
environment step.
|
|
||||||
dropout: Dropout to use in the transformer layers (see code for details).
|
|
||||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
|
||||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from lerobot import ActionChunkingTransformerConfig
|
|
||||||
|
|
||||||
>>> # Initializing an ACT style configuration
|
|
||||||
>>> configuration = ActionChunkingTransformerConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model (with random weights) from the ACT style configuration
|
|
||||||
>>> model = ActionChunkingTransformerPolicy(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
# Input / output structure.
|
|
||||||
n_obs_steps: int = 1
|
|
||||||
chunk_size: int = 100
|
|
||||||
n_action_steps: int = 100
|
|
||||||
|
|
||||||
input_shapes: dict[str, list[str]] = {
|
|
||||||
"observation.images.top": [3, 480, 640],
|
|
||||||
"observation.state": [14],
|
|
||||||
}
|
|
||||||
|
|
||||||
output_shapes: dict[str, list[str]] = {"action": [14]}
|
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
|
||||||
normalize_input_modes: dict[str, str] = {
|
|
||||||
"observation.image": "mean_std",
|
|
||||||
"observation.state": "mean_std",
|
|
||||||
}
|
|
||||||
|
|
||||||
unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
|
|
||||||
|
|
||||||
# Architecture.
|
|
||||||
# Vision backbone.
|
|
||||||
vision_backbone: str = "resnet18"
|
|
||||||
use_pretrained_backbone: bool = True
|
|
||||||
replace_final_stride_with_dilation: int = False
|
|
||||||
# Transformer layers.
|
|
||||||
pre_norm: bool = False
|
|
||||||
d_model: int = 512
|
|
||||||
n_heads: int = 8
|
|
||||||
dim_feedforward: int = 3200
|
|
||||||
feedforward_activation: str = "relu"
|
|
||||||
n_encoder_layers: int = 4
|
|
||||||
n_decoder_layers: int = 1
|
|
||||||
# VAE.
|
|
||||||
use_vae: bool = True
|
|
||||||
latent_dim: int = 32
|
|
||||||
n_vae_encoder_layers: int = 4
|
|
||||||
|
|
||||||
# Inference.
|
|
||||||
use_temporal_aggregation: bool = False
|
|
||||||
|
|
||||||
# Training and loss computation.
|
|
||||||
dropout: float = 0.1
|
|
||||||
kl_weight: float = 10.0
|
|
||||||
|
|
||||||
# ---
|
|
||||||
# TODO(alexander-soare): Remove these from the policy config.
|
|
||||||
batch_size: int = 8
|
|
||||||
lr: float = 1e-5
|
|
||||||
lr_backbone: float = 1e-5
|
|
||||||
weight_decay: float = 1e-4
|
|
||||||
grad_clip_norm: float = 10
|
|
||||||
utd: int = 1
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Input validation (not exhaustive)."""
|
|
||||||
if not self.vision_backbone.startswith("resnet"):
|
|
||||||
raise ValueError(
|
|
||||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
|
||||||
)
|
|
||||||
if self.use_temporal_aggregation:
|
|
||||||
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
|
||||||
if self.n_action_steps > self.chunk_size:
|
|
||||||
raise ValueError(
|
|
||||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
|
||||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
|
||||||
)
|
|
||||||
if self.n_obs_steps != 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
|
||||||
)
|
|
||||||
# Check that there is only one image.
|
|
||||||
# TODO(alexander-soare): generalize this to multiple images.
|
|
||||||
if (
|
|
||||||
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
|
|
||||||
or "observation.images.top" not in self.input_shapes
|
|
||||||
):
|
|
||||||
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
|
|
||||||
212
lerobot/common/policies/act/detr_vae.py
Normal file
212
lerobot/common/policies/act/detr_vae.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
from .backbone import build_backbone
|
||||||
|
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||||
|
|
||||||
|
|
||||||
|
def reparametrize(mu, logvar):
|
||||||
|
std = logvar.div(2).exp()
|
||||||
|
eps = Variable(std.data.new(std.size()).normal_())
|
||||||
|
return mu + std * eps
|
||||||
|
|
||||||
|
|
||||||
|
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||||
|
def get_position_angle_vec(position):
|
||||||
|
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||||
|
|
||||||
|
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||||
|
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||||
|
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||||
|
|
||||||
|
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class DETRVAE(nn.Module):
|
||||||
|
"""This is the DETR module that performs object detection"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||||
|
):
|
||||||
|
"""Initializes the model.
|
||||||
|
Parameters:
|
||||||
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
|
transformer: torch module of the transformer architecture. See transformer.py
|
||||||
|
state_dim: robot state dimension of the environment
|
||||||
|
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||||
|
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.transformer = transformer
|
||||||
|
self.encoder = encoder
|
||||||
|
self.vae = vae
|
||||||
|
hidden_dim = transformer.d_model
|
||||||
|
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||||
|
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||||
|
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||||
|
if backbones is not None:
|
||||||
|
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||||
|
self.backbones = nn.ModuleList(backbones)
|
||||||
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
|
else:
|
||||||
|
# input_dim = 14 + 7 # robot_state + env_state
|
||||||
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
|
# TODO(rcadene): understand what is env_state, and why it needs to be 7
|
||||||
|
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
|
||||||
|
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||||
|
self.backbones = None
|
||||||
|
|
||||||
|
# encoder extra parameters
|
||||||
|
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||||
|
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||||
|
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||||
|
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||||
|
self.latent_proj = nn.Linear(
|
||||||
|
hidden_dim, self.latent_dim * 2
|
||||||
|
) # project hidden state to latent std, var
|
||||||
|
self.register_buffer(
|
||||||
|
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||||
|
) # [CLS], qpos, a_seq
|
||||||
|
|
||||||
|
# decoder extra parameters
|
||||||
|
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||||
|
self.additional_pos_embed = nn.Embedding(
|
||||||
|
2, hidden_dim
|
||||||
|
) # learned position embedding for proprio and latent
|
||||||
|
|
||||||
|
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||||
|
"""
|
||||||
|
qpos: batch, qpos_dim
|
||||||
|
image: batch, num_cam, channel, height, width
|
||||||
|
env_state: None
|
||||||
|
actions: batch, seq, action_dim
|
||||||
|
"""
|
||||||
|
is_training = actions is not None # train or val
|
||||||
|
bs, _ = qpos.shape
|
||||||
|
### Obtain latent z from action sequence
|
||||||
|
if self.vae and is_training:
|
||||||
|
# project action sequence to embedding dim, and concat with a CLS token
|
||||||
|
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||||
|
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||||
|
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||||
|
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||||
|
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||||
|
encoder_input = torch.cat(
|
||||||
|
[cls_embed, qpos_embed, action_embed], axis=1
|
||||||
|
) # (bs, seq+1, hidden_dim)
|
||||||
|
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||||
|
# do not mask cls token
|
||||||
|
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||||
|
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||||
|
# obtain position embedding
|
||||||
|
pos_embed = self.pos_table.clone().detach()
|
||||||
|
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||||
|
# query model
|
||||||
|
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||||
|
encoder_output = encoder_output[0] # take cls output only
|
||||||
|
latent_info = self.latent_proj(encoder_output)
|
||||||
|
mu = latent_info[:, : self.latent_dim]
|
||||||
|
logvar = latent_info[:, self.latent_dim :]
|
||||||
|
latent_sample = reparametrize(mu, logvar)
|
||||||
|
latent_input = self.latent_out_proj(latent_sample)
|
||||||
|
else:
|
||||||
|
mu = logvar = None
|
||||||
|
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||||
|
latent_input = self.latent_out_proj(latent_sample)
|
||||||
|
|
||||||
|
if self.backbones is not None:
|
||||||
|
# Image observation features and position embeddings
|
||||||
|
all_cam_features = []
|
||||||
|
all_cam_pos = []
|
||||||
|
for cam_id, _ in enumerate(self.camera_names):
|
||||||
|
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||||
|
features = features[0] # take the last layer feature
|
||||||
|
pos = pos[0]
|
||||||
|
all_cam_features.append(self.input_proj(features))
|
||||||
|
all_cam_pos.append(pos)
|
||||||
|
# proprioception features
|
||||||
|
proprio_input = self.input_proj_robot_state(qpos)
|
||||||
|
# fold camera dimension into width dimension
|
||||||
|
src = torch.cat(all_cam_features, axis=3)
|
||||||
|
pos = torch.cat(all_cam_pos, axis=3)
|
||||||
|
hs = self.transformer(
|
||||||
|
src,
|
||||||
|
None,
|
||||||
|
self.query_embed.weight,
|
||||||
|
pos,
|
||||||
|
latent_input,
|
||||||
|
proprio_input,
|
||||||
|
self.additional_pos_embed.weight,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
qpos = self.input_proj_robot_state(qpos)
|
||||||
|
env_state = self.input_proj_env_state(env_state)
|
||||||
|
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||||
|
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||||
|
a_hat = self.action_head(hs)
|
||||||
|
is_pad_hat = self.is_pad_head(hs)
|
||||||
|
return a_hat, is_pad_hat, [mu, logvar]
|
||||||
|
|
||||||
|
|
||||||
|
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||||
|
if hidden_depth == 0:
|
||||||
|
mods = [nn.Linear(input_dim, output_dim)]
|
||||||
|
else:
|
||||||
|
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||||
|
for _ in range(hidden_depth - 1):
|
||||||
|
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||||
|
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||||
|
trunk = nn.Sequential(*mods)
|
||||||
|
return trunk
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder(args):
|
||||||
|
d_model = args.hidden_dim # 256
|
||||||
|
dropout = args.dropout # 0.1
|
||||||
|
nhead = args.nheads # 8
|
||||||
|
dim_feedforward = args.dim_feedforward # 2048
|
||||||
|
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||||
|
normalize_before = args.pre_norm # False
|
||||||
|
activation = "relu"
|
||||||
|
|
||||||
|
encoder_layer = TransformerEncoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||||
|
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def build(args):
|
||||||
|
# From state
|
||||||
|
# backbone = None # from state for now, no need for conv nets
|
||||||
|
# From image
|
||||||
|
backbones = []
|
||||||
|
backbone = build_backbone(args)
|
||||||
|
backbones.append(backbone)
|
||||||
|
|
||||||
|
transformer = build_transformer(args)
|
||||||
|
|
||||||
|
encoder = build_encoder(args)
|
||||||
|
|
||||||
|
model = DETRVAE(
|
||||||
|
backbones,
|
||||||
|
transformer,
|
||||||
|
encoder,
|
||||||
|
state_dim=args.state_dim,
|
||||||
|
action_dim=args.action_dim,
|
||||||
|
num_queries=args.num_queries,
|
||||||
|
camera_names=args.camera_names,
|
||||||
|
vae=args.vae,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||||
|
|
||||||
|
return model
|
||||||
@@ -1,606 +0,0 @@
|
|||||||
"""Action Chunking Transformer Policy
|
|
||||||
|
|
||||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
|
||||||
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
from collections import deque
|
|
||||||
from itertools import chain
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
import torchvision
|
|
||||||
from torch import Tensor, nn
|
|
||||||
from torchvision.models._utils import IntermediateLayerGetter
|
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(nn.Module):
|
|
||||||
"""
|
|
||||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
|
||||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
|
||||||
|
|
||||||
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
|
||||||
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
|
||||||
model that encodes the target data (a sequence of actions), and the condition (the robot
|
|
||||||
joint-space).
|
|
||||||
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
|
|
||||||
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
|
|
||||||
have an option to train this model without the variational objective (in which case we drop the
|
|
||||||
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
|
|
||||||
|
|
||||||
Transformer
|
|
||||||
Used alone for inference
|
|
||||||
(acts as VAE decoder
|
|
||||||
during training)
|
|
||||||
┌───────────────────────┐
|
|
||||||
│ Outputs │
|
|
||||||
│ ▲ │
|
|
||||||
│ ┌─────►┌───────┐ │
|
|
||||||
┌──────┐ │ │ │Transf.│ │
|
|
||||||
│ │ │ ├─────►│decoder│ │
|
|
||||||
┌────┴────┐ │ │ │ │ │ │
|
|
||||||
│ │ │ │ ┌───┴───┬─►│ │ │
|
|
||||||
│ VAE │ │ │ │ │ └───────┘ │
|
|
||||||
│ encoder │ │ │ │Transf.│ │
|
|
||||||
│ │ │ │ │encoder│ │
|
|
||||||
└───▲─────┘ │ │ │ │ │
|
|
||||||
│ │ │ └───▲───┘ │
|
|
||||||
│ │ │ │ │
|
|
||||||
inputs └─────┼─────┘ │
|
|
||||||
│ │
|
|
||||||
└───────────────────────┘
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "act"
|
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
|
||||||
configuration class is used.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
if cfg is None:
|
|
||||||
cfg = ActionChunkingTransformerConfig()
|
|
||||||
self.cfg = cfg
|
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
|
||||||
if self.cfg.use_vae:
|
|
||||||
self.vae_encoder = _TransformerEncoder(cfg)
|
|
||||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
|
||||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
|
||||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
|
||||||
)
|
|
||||||
# Projection layer for action (joint-space target) to hidden dimension.
|
|
||||||
self.vae_encoder_action_input_proj = nn.Linear(
|
|
||||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
|
||||||
)
|
|
||||||
self.latent_dim = cfg.latent_dim
|
|
||||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
|
||||||
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
|
|
||||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
|
||||||
# dimension.
|
|
||||||
self.register_buffer(
|
|
||||||
"vae_encoder_pos_enc",
|
|
||||||
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Backbone for image feature extraction.
|
|
||||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
|
||||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
|
||||||
pretrained=cfg.use_pretrained_backbone,
|
|
||||||
norm_layer=FrozenBatchNorm2d,
|
|
||||||
)
|
|
||||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
|
||||||
# map).
|
|
||||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
|
||||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
|
||||||
|
|
||||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
|
||||||
self.encoder = _TransformerEncoder(cfg)
|
|
||||||
self.decoder = _TransformerDecoder(cfg)
|
|
||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
|
||||||
# [latent, robot_state, image_feature_map_pixels].
|
|
||||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
|
|
||||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
|
|
||||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
|
||||||
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
|
|
||||||
)
|
|
||||||
# Transformer encoder positional embeddings.
|
|
||||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
|
|
||||||
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
|
|
||||||
|
|
||||||
# Transformer decoder.
|
|
||||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
|
||||||
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
|
|
||||||
|
|
||||||
# Final action regression head on the output of the transformer's decoder.
|
|
||||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
|
||||||
|
|
||||||
self._reset_parameters()
|
|
||||||
self._create_optimizer()
|
|
||||||
|
|
||||||
def _create_optimizer(self):
|
|
||||||
optimizer_params_dicts = [
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
|
||||||
],
|
|
||||||
"lr": self.cfg.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
|
||||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
|
||||||
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
|
||||||
if p.dim() > 1:
|
|
||||||
nn.init.xavier_uniform_(p)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""This should be called whenever the environment is reset."""
|
|
||||||
if self.cfg.n_action_steps is not None:
|
|
||||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
|
||||||
|
|
||||||
@torch.no_grad
|
|
||||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
|
||||||
"""Select a single action given environment observations.
|
|
||||||
|
|
||||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
|
||||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
|
||||||
queue is empty.
|
|
||||||
"""
|
|
||||||
self.eval()
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
|
||||||
|
|
||||||
if len(self._action_queue) == 0:
|
|
||||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
|
||||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
|
||||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
|
||||||
|
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
|
||||||
return self._action_queue.popleft()
|
|
||||||
|
|
||||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
|
||||||
|
|
||||||
l1_loss = (
|
|
||||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
|
||||||
).mean()
|
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss}
|
|
||||||
if self.cfg.use_vae:
|
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
|
||||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
|
||||||
mean_kld = (
|
|
||||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
|
||||||
)
|
|
||||||
loss_dict["kld_loss"] = mean_kld
|
|
||||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
|
||||||
else:
|
|
||||||
loss_dict["loss"] = l1_loss
|
|
||||||
|
|
||||||
return loss_dict
|
|
||||||
|
|
||||||
def update(self, batch, **_) -> dict:
|
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
|
||||||
start_time = time.time()
|
|
||||||
self.train()
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
|
||||||
|
|
||||||
loss_dict = self.forward(batch)
|
|
||||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
|
||||||
loss = loss_dict["loss"]
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
|
||||||
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer.step()
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"loss": loss.item(),
|
|
||||||
"grad_norm": float(grad_norm),
|
|
||||||
"lr": self.cfg.lr,
|
|
||||||
"update_s": time.time() - start_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
|
||||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
|
||||||
|
|
||||||
This function expects `batch` to have (at least):
|
|
||||||
{
|
|
||||||
"observation.state": (B, state_dim) batch of robot states.
|
|
||||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Stack images in the order dictated by input_shapes.
|
|
||||||
batch["observation.images"] = torch.stack(
|
|
||||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
|
||||||
dim=-4,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
|
||||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
|
||||||
|
|
||||||
`batch` should have the following structure:
|
|
||||||
|
|
||||||
{
|
|
||||||
"observation.state": (B, state_dim) batch of robot states.
|
|
||||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
|
||||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
|
||||||
}
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(B, chunk_size, action_dim) batch of action sequences
|
|
||||||
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
|
||||||
latent dimension.
|
|
||||||
"""
|
|
||||||
if self.cfg.use_vae and self.training:
|
|
||||||
assert (
|
|
||||||
"action" in batch
|
|
||||||
), "actions must be provided when using the variational objective in training mode."
|
|
||||||
|
|
||||||
self._stack_images(batch)
|
|
||||||
|
|
||||||
batch_size = batch["observation.state"].shape[0]
|
|
||||||
|
|
||||||
# Prepare the latent for input to the transformer encoder.
|
|
||||||
if self.cfg.use_vae and "action" in batch:
|
|
||||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
|
||||||
cls_embed = einops.repeat(
|
|
||||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
|
||||||
) # (B, 1, D)
|
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
|
|
||||||
1
|
|
||||||
) # (B, 1, D)
|
|
||||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
|
||||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
|
||||||
|
|
||||||
# Prepare fixed positional embedding.
|
|
||||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
|
||||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
|
||||||
|
|
||||||
# Forward pass through VAE encoder to get the latent PDF parameters.
|
|
||||||
cls_token_out = self.vae_encoder(
|
|
||||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
|
||||||
)[0] # select the class token, with shape (B, D)
|
|
||||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
|
||||||
mu = latent_pdf_params[:, : self.latent_dim]
|
|
||||||
# This is 2log(sigma). Done this way to match the original implementation.
|
|
||||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
|
||||||
|
|
||||||
# Sample the latent with the reparameterization trick.
|
|
||||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
|
||||||
else:
|
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
|
||||||
mu = log_sigma_x2 = None
|
|
||||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
|
||||||
batch["observation.state"].device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare all other transformer encoder inputs.
|
|
||||||
# Camera observation features and positional embeddings.
|
|
||||||
all_cam_features = []
|
|
||||||
all_cam_pos_embeds = []
|
|
||||||
images = batch["observation.images"]
|
|
||||||
for cam_index in range(images.shape[-4]):
|
|
||||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
|
||||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
|
||||||
all_cam_features.append(cam_features)
|
|
||||||
all_cam_pos_embeds.append(cam_pos_embed)
|
|
||||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
|
||||||
encoder_in = torch.cat(all_cam_features, axis=3)
|
|
||||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
|
||||||
|
|
||||||
# Get positional embeddings for robot state and latent.
|
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
|
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
|
||||||
|
|
||||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
|
||||||
encoder_in = torch.cat(
|
|
||||||
[
|
|
||||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
|
||||||
encoder_in.flatten(2).permute(2, 0, 1),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
pos_embed = torch.cat(
|
|
||||||
[
|
|
||||||
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
|
|
||||||
cam_pos_embed.flatten(2).permute(2, 0, 1),
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Forward pass through the transformer modules.
|
|
||||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
|
||||||
decoder_in = torch.zeros(
|
|
||||||
(self.cfg.chunk_size, batch_size, self.cfg.d_model),
|
|
||||||
dtype=pos_embed.dtype,
|
|
||||||
device=pos_embed.device,
|
|
||||||
)
|
|
||||||
decoder_out = self.decoder(
|
|
||||||
decoder_in,
|
|
||||||
encoder_out,
|
|
||||||
encoder_pos_embed=pos_embed,
|
|
||||||
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Move back to (B, S, C).
|
|
||||||
decoder_out = decoder_out.transpose(0, 1)
|
|
||||||
|
|
||||||
actions = self.action_head(decoder_out)
|
|
||||||
|
|
||||||
return actions, (mu, log_sigma_x2)
|
|
||||||
|
|
||||||
def save(self, fp):
|
|
||||||
torch.save(self.state_dict(), fp)
|
|
||||||
|
|
||||||
def load(self, fp):
|
|
||||||
d = torch.load(fp)
|
|
||||||
self.load_state_dict(d)
|
|
||||||
|
|
||||||
|
|
||||||
class _TransformerEncoder(nn.Module):
|
|
||||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
|
|
||||||
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x, pos_embed=pos_embed)
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class _TransformerEncoderLayer(nn.Module):
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
|
||||||
|
|
||||||
# Feed forward layers.
|
|
||||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(cfg.dropout)
|
|
||||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
|
||||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
|
||||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
|
||||||
self.pre_norm = cfg.pre_norm
|
|
||||||
|
|
||||||
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
|
||||||
skip = x
|
|
||||||
if self.pre_norm:
|
|
||||||
x = self.norm1(x)
|
|
||||||
q = k = x if pos_embed is None else x + pos_embed
|
|
||||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
|
||||||
x = skip + self.dropout1(x)
|
|
||||||
if self.pre_norm:
|
|
||||||
skip = x
|
|
||||||
x = self.norm2(x)
|
|
||||||
else:
|
|
||||||
x = self.norm1(x)
|
|
||||||
skip = x
|
|
||||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
||||||
x = skip + self.dropout2(x)
|
|
||||||
if not self.pre_norm:
|
|
||||||
x = self.norm2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class _TransformerDecoder(nn.Module):
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
|
||||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
|
||||||
super().__init__()
|
|
||||||
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
|
|
||||||
self.norm = nn.LayerNorm(cfg.d_model)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
encoder_out: Tensor,
|
|
||||||
decoder_pos_embed: Tensor | None = None,
|
|
||||||
encoder_pos_embed: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(
|
|
||||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
|
||||||
)
|
|
||||||
if self.norm is not None:
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class _TransformerDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
|
||||||
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
|
||||||
|
|
||||||
# Feed forward layers.
|
|
||||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
|
||||||
self.dropout = nn.Dropout(cfg.dropout)
|
|
||||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
|
||||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
|
||||||
self.norm3 = nn.LayerNorm(cfg.d_model)
|
|
||||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
|
||||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
|
||||||
self.dropout3 = nn.Dropout(cfg.dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
|
||||||
self.pre_norm = cfg.pre_norm
|
|
||||||
|
|
||||||
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
|
||||||
return tensor if pos_embed is None else tensor + pos_embed
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
encoder_out: Tensor,
|
|
||||||
decoder_pos_embed: Tensor | None = None,
|
|
||||||
encoder_pos_embed: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
|
||||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
|
||||||
cross-attending with.
|
|
||||||
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
|
||||||
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
|
||||||
Returns:
|
|
||||||
(DS, B, C) tensor of decoder output features.
|
|
||||||
"""
|
|
||||||
skip = x
|
|
||||||
if self.pre_norm:
|
|
||||||
x = self.norm1(x)
|
|
||||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
|
||||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
|
||||||
x = skip + self.dropout1(x)
|
|
||||||
if self.pre_norm:
|
|
||||||
skip = x
|
|
||||||
x = self.norm2(x)
|
|
||||||
else:
|
|
||||||
x = self.norm1(x)
|
|
||||||
skip = x
|
|
||||||
x = self.multihead_attn(
|
|
||||||
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
|
||||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
|
||||||
value=encoder_out,
|
|
||||||
)[0] # select just the output, not the attention weights
|
|
||||||
x = skip + self.dropout2(x)
|
|
||||||
if self.pre_norm:
|
|
||||||
skip = x
|
|
||||||
x = self.norm3(x)
|
|
||||||
else:
|
|
||||||
x = self.norm2(x)
|
|
||||||
skip = x
|
|
||||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
|
||||||
x = skip + self.dropout3(x)
|
|
||||||
if not self.pre_norm:
|
|
||||||
x = self.norm3(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
|
|
||||||
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_positions: Number of token positions required.
|
|
||||||
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_position_angle_vec(position):
|
|
||||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
|
||||||
|
|
||||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
|
||||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
|
||||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
|
||||||
return torch.from_numpy(sinusoid_table).float()
|
|
||||||
|
|
||||||
|
|
||||||
class _SinusoidalPositionEmbedding2D(nn.Module):
|
|
||||||
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
|
||||||
|
|
||||||
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
|
||||||
for the vertical direction, and 1/W for the horizontal direction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dimension: int):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
dimension: The desired dimension of the embeddings.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.dimension = dimension
|
|
||||||
self._two_pi = 2 * math.pi
|
|
||||||
self._eps = 1e-6
|
|
||||||
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
|
||||||
self._temperature = 10000
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
|
|
||||||
Returns:
|
|
||||||
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
|
||||||
"""
|
|
||||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
|
||||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
|
||||||
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
|
||||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
|
|
||||||
# "Normalize" the position index such that it ranges in [0, 2π].
|
|
||||||
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
|
||||||
# are non-zero by construction. This is an artifact of the original code.
|
|
||||||
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
|
||||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
|
||||||
|
|
||||||
inverse_frequency = self._temperature ** (
|
|
||||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
|
||||||
)
|
|
||||||
|
|
||||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
|
||||||
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
|
||||||
|
|
||||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
|
||||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
|
||||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
|
||||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
|
||||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
|
||||||
|
|
||||||
return pos_embed
|
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation: str) -> Callable:
|
|
||||||
"""Return an activation function given a string."""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
|
||||||
216
lerobot/common/policies/act/policy.py
Normal file
216
lerobot/common/policies/act/policy.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
|
from lerobot.common.policies.act.detr_vae import build
|
||||||
|
from lerobot.common.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def build_act_model_and_optimizer(cfg):
|
||||||
|
model = build(cfg)
|
||||||
|
|
||||||
|
param_dicts = [
|
||||||
|
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||||
|
{
|
||||||
|
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||||
|
"lr": cfg.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||||
|
|
||||||
|
return model, optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def kl_divergence(mu, logvar):
|
||||||
|
batch_size = mu.size(0)
|
||||||
|
assert batch_size != 0
|
||||||
|
if mu.data.ndimension() == 4:
|
||||||
|
mu = mu.view(mu.size(0), mu.size(1))
|
||||||
|
if logvar.data.ndimension() == 4:
|
||||||
|
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||||
|
|
||||||
|
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
|
total_kld = klds.sum(1).mean(0, True)
|
||||||
|
dimension_wise_kld = klds.mean(0)
|
||||||
|
mean_kld = klds.mean(1).mean(0, True)
|
||||||
|
|
||||||
|
return total_kld, dimension_wise_kld, mean_kld
|
||||||
|
|
||||||
|
|
||||||
|
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||||
|
name = "act"
|
||||||
|
|
||||||
|
def __init__(self, cfg, device, n_action_steps=1):
|
||||||
|
super().__init__(n_action_steps)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.n_action_steps = n_action_steps
|
||||||
|
self.device = get_safe_torch_device(device)
|
||||||
|
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||||
|
self.kl_weight = self.cfg.kl_weight
|
||||||
|
logging.info(f"KL Weight {self.kl_weight}")
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
def update(self, replay_buffer, step):
|
||||||
|
del step
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
|
||||||
|
num_slices = self.cfg.batch_size
|
||||||
|
batch_size = self.cfg.horizon * num_slices
|
||||||
|
|
||||||
|
assert batch_size % self.cfg.horizon == 0
|
||||||
|
assert batch_size % num_slices == 0
|
||||||
|
|
||||||
|
def process_batch(batch, horizon, num_slices):
|
||||||
|
# trajectory t = 64, horizon h = 16
|
||||||
|
# (t h) ... -> t h ...
|
||||||
|
batch = batch.reshape(num_slices, horizon)
|
||||||
|
|
||||||
|
image = batch["observation", "image", "top"]
|
||||||
|
image = image[:, 0] # first observation t=0
|
||||||
|
# batch, num_cam, channel, height, width
|
||||||
|
image = image.unsqueeze(1)
|
||||||
|
assert image.ndim == 5
|
||||||
|
image = image.float()
|
||||||
|
|
||||||
|
state = batch["observation", "state"]
|
||||||
|
state = state[:, 0] # first observation t=0
|
||||||
|
# batch, qpos_dim
|
||||||
|
assert state.ndim == 2
|
||||||
|
|
||||||
|
action = batch["action"]
|
||||||
|
# batch, seq, action_dim
|
||||||
|
assert action.ndim == 3
|
||||||
|
assert action.shape[1] == horizon
|
||||||
|
|
||||||
|
if self.cfg.n_obs_steps > 1:
|
||||||
|
raise NotImplementedError()
|
||||||
|
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||||
|
# image = image[:, : self.cfg.n_obs_steps]
|
||||||
|
# state = state[:, : self.cfg.n_obs_steps]
|
||||||
|
|
||||||
|
out = {
|
||||||
|
"obs": {
|
||||||
|
"image": image.to(self.device, non_blocking=True),
|
||||||
|
"agent_pos": state.to(self.device, non_blocking=True),
|
||||||
|
},
|
||||||
|
"action": action.to(self.device, non_blocking=True),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
batch = replay_buffer.sample(batch_size)
|
||||||
|
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||||
|
|
||||||
|
data_s = time.time() - start_time
|
||||||
|
|
||||||
|
loss = self.compute_loss(batch)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
|
self.model.parameters(),
|
||||||
|
self.cfg.grad_clip_norm,
|
||||||
|
error_if_nonfinite=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
# self.lr_scheduler.step()
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"loss": loss.item(),
|
||||||
|
"grad_norm": float(grad_norm),
|
||||||
|
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||||
|
"lr": self.cfg.lr,
|
||||||
|
"data_s": data_s,
|
||||||
|
"update_s": time.time() - start_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
def save(self, fp):
|
||||||
|
torch.save(self.state_dict(), fp)
|
||||||
|
|
||||||
|
def load(self, fp, device=None):
|
||||||
|
d = torch.load(fp, map_location=device)
|
||||||
|
self.load_state_dict(d)
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
loss_dict = self._forward(
|
||||||
|
qpos=batch["obs"]["agent_pos"],
|
||||||
|
image=batch["obs"]["image"],
|
||||||
|
actions=batch["action"],
|
||||||
|
)
|
||||||
|
loss = loss_dict["loss"]
|
||||||
|
return loss
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_actions(self, observation, step_count):
|
||||||
|
if observation["image"].shape[0] != 1:
|
||||||
|
raise NotImplementedError("Batch size > 1 not handled")
|
||||||
|
|
||||||
|
# TODO(rcadene): remove unused step_count
|
||||||
|
del step_count
|
||||||
|
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
# TODO(rcadene): remove hack
|
||||||
|
# add 1 camera dimension
|
||||||
|
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||||
|
|
||||||
|
obs_dict = {
|
||||||
|
"image": observation["image", "top"],
|
||||||
|
"agent_pos": observation["state"],
|
||||||
|
}
|
||||||
|
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||||
|
|
||||||
|
if self.cfg.temporal_agg:
|
||||||
|
# TODO(rcadene): implement temporal aggregation
|
||||||
|
raise NotImplementedError()
|
||||||
|
# all_time_actions[[t], t:t+num_queries] = action
|
||||||
|
# actions_for_curr_step = all_time_actions[:, t]
|
||||||
|
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||||
|
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||||
|
# k = 0.01
|
||||||
|
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||||
|
# exp_weights = exp_weights / exp_weights.sum()
|
||||||
|
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||||
|
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# take first predicted action or n first actions
|
||||||
|
action = action[: self.n_action_steps]
|
||||||
|
return action
|
||||||
|
|
||||||
|
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
env_state = None
|
||||||
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
image = normalize(image)
|
||||||
|
|
||||||
|
is_training = actions is not None
|
||||||
|
if is_training: # training time
|
||||||
|
actions = actions[:, : self.model.num_queries]
|
||||||
|
if is_pad is not None:
|
||||||
|
is_pad = is_pad[:, : self.model.num_queries]
|
||||||
|
|
||||||
|
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||||
|
|
||||||
|
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||||
|
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||||
|
|
||||||
|
loss_dict = {}
|
||||||
|
loss_dict["l1"] = l1
|
||||||
|
if self.cfg.vae:
|
||||||
|
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||||
|
loss_dict["kl"] = total_kld[0]
|
||||||
|
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||||
|
else:
|
||||||
|
loss_dict["loss"] = loss_dict["l1"]
|
||||||
|
return loss_dict
|
||||||
|
else:
|
||||||
|
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||||
|
return action
|
||||||
102
lerobot/common/policies/act/position_encoding.py
Normal file
102
lerobot/common/policies/act/position_encoding.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
Various positional encodings for the transformer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .utils import NestedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingSine(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a more standard version of the position embedding, very similar to the one
|
||||||
|
used by the Attention is all you need paper, generalized to work on images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_pos_feats = num_pos_feats
|
||||||
|
self.temperature = temperature
|
||||||
|
self.normalize = normalize
|
||||||
|
if scale is not None and normalize is False:
|
||||||
|
raise ValueError("normalize should be True if scale is passed")
|
||||||
|
if scale is None:
|
||||||
|
scale = 2 * math.pi
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
x = tensor
|
||||||
|
# mask = tensor_list.mask
|
||||||
|
# assert mask is not None
|
||||||
|
# not_mask = ~mask
|
||||||
|
|
||||||
|
not_mask = torch.ones_like(x[0, [0]])
|
||||||
|
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||||
|
if self.normalize:
|
||||||
|
eps = 1e-6
|
||||||
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||||
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||||
|
|
||||||
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||||
|
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||||
|
|
||||||
|
pos_x = x_embed[:, :, :, None] / dim_t
|
||||||
|
pos_y = y_embed[:, :, :, None] / dim_t
|
||||||
|
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||||
|
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||||
|
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingLearned(nn.Module):
|
||||||
|
"""
|
||||||
|
Absolute pos embedding, learned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_pos_feats=256):
|
||||||
|
super().__init__()
|
||||||
|
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.uniform_(self.row_embed.weight)
|
||||||
|
nn.init.uniform_(self.col_embed.weight)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
h, w = x.shape[-2:]
|
||||||
|
i = torch.arange(w, device=x.device)
|
||||||
|
j = torch.arange(h, device=x.device)
|
||||||
|
x_emb = self.col_embed(i)
|
||||||
|
y_emb = self.row_embed(j)
|
||||||
|
pos = (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||||
|
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
.permute(2, 0, 1)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.repeat(x.shape[0], 1, 1, 1)
|
||||||
|
)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_position_encoding(args):
|
||||||
|
n_steps = args.hidden_dim // 2
|
||||||
|
if args.position_embedding in ("v2", "sine"):
|
||||||
|
# TODO find a better way of exposing other arguments
|
||||||
|
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
||||||
|
elif args.position_embedding in ("v3", "learned"):
|
||||||
|
position_embedding = PositionEmbeddingLearned(n_steps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"not supported {args.position_embedding}")
|
||||||
|
|
||||||
|
return position_embedding
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user