Compare commits
11 Commits
qgallouede
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bf36cd413 | ||
|
|
b49f7b70e2 | ||
|
|
f1230cdac0 | ||
|
|
5395829596 | ||
|
|
a45802c281 | ||
|
|
167a51cb69 | ||
|
|
fbc66a082b | ||
|
|
603455e313 | ||
|
|
6500945be5 | ||
|
|
ebbcad8c05 | ||
|
|
d98b435b4c |
142
.dockerignore
142
.dockerignore
@@ -1,142 +0,0 @@
|
||||
# Misc
|
||||
.git
|
||||
tmp
|
||||
wandb
|
||||
data
|
||||
outputs
|
||||
.vscode
|
||||
rl
|
||||
media
|
||||
|
||||
|
||||
# Logging
|
||||
logs
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
*.key
|
||||
|
||||
# Slurm
|
||||
sbatch*.sh
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
2
.gitattributes
vendored
2
.gitattributes
vendored
@@ -1,2 +0,0 @@
|
||||
*.memmap filter=lfs diff=lfs merge=lfs -text
|
||||
*.stl filter=lfs diff=lfs merge=lfs -text
|
||||
54
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
54
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,54 +0,0 @@
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve LeRobot
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to submit a bug report! 🐛
|
||||
If this is not a bug related to the LeRobot library directly, but instead a general question about your code or the library specifically please use our [discord](https://discord.gg/s3KuuzsPFb).
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
|
||||
render: Shell
|
||||
placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "One of the scripts in the examples/ folder of LeRobot"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
If needed, provide a simple code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
Sharing error messages or stack traces could be useful as well!
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Try to avoid screenshots, as they are hard to read and don't allow copy-and-pasting.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
||||
32
.github/PULL_REQUEST_TEMPLATE.md
vendored
32
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,32 +0,0 @@
|
||||
# What does this PR do?
|
||||
|
||||
Examples:
|
||||
- Fixes # (issue)
|
||||
- Adds new dataset
|
||||
- Optimizes something
|
||||
|
||||
## How was it tested?
|
||||
|
||||
Examples:
|
||||
- Added `test_something` in `tests/test_stuff.py`.
|
||||
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
|
||||
- Optimized `some_function`, it now runs X times faster than previously.
|
||||
|
||||
## How to checkout & try? (for the reviewer)
|
||||
|
||||
Examples:
|
||||
```bash
|
||||
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
|
||||
```
|
||||
```bash
|
||||
python lerobot/scripts/train.py --some.option=true
|
||||
```
|
||||
|
||||
## Before submitting
|
||||
Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
||||
|
||||
|
||||
## Who can review?
|
||||
|
||||
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
||||
members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people.
|
||||
30
.github/scripts/dep_build.py
vendored
30
.github/scripts/dep_build.py
vendored
@@ -1,30 +0,0 @@
|
||||
PYPROJECT = "pyproject.toml"
|
||||
DEPS = {
|
||||
"gym-pusht": '{ git = "git@github.com:huggingface/gym-pusht.git", optional = true}',
|
||||
"gym-xarm": '{ git = "git@github.com:huggingface/gym-xarm.git", optional = true}',
|
||||
"gym-aloha": '{ git = "git@github.com:huggingface/gym-aloha.git", optional = true}',
|
||||
}
|
||||
|
||||
|
||||
def update_envs_as_path_dependencies():
|
||||
with open(PYPROJECT) as file:
|
||||
lines = file.readlines()
|
||||
|
||||
new_lines = []
|
||||
for line in lines:
|
||||
if any(dep in line for dep in DEPS.values()):
|
||||
for dep in DEPS:
|
||||
if dep in line:
|
||||
new_line = f'{dep} = {{ path = "envs/{dep}/", optional = true}}\n'
|
||||
new_lines.append(new_line)
|
||||
break
|
||||
|
||||
else:
|
||||
new_lines.append(line)
|
||||
|
||||
with open(PYPROJECT, "w") as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
update_envs_as_path_dependencies()
|
||||
203
.github/workflows/build-docker-images.yml
vendored
203
.github/workflows/build-docker-images.yml
vendored
@@ -1,203 +0,0 @@
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||
name: Builds
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
# CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
latest-cpu:
|
||||
name: "Build CPU"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# HACK(aliberts): to be removed for release
|
||||
# -----------------------------------------
|
||||
- name: Checkout gym-aloha
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-aloha
|
||||
path: envs/gym-aloha
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-xarm
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-xarm
|
||||
path: envs/gym-xarm
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-pusht
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-pusht
|
||||
path: envs/gym-pusht
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Change envs dependencies as local path
|
||||
run: python .github/scripts/dep_build.py
|
||||
# -----------------------------------------
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push CPU
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-cpu/Dockerfile
|
||||
push: true
|
||||
tags: huggingface/lerobot-cpu
|
||||
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||
|
||||
# - name: Post to a Slack channel
|
||||
# id: slack
|
||||
# #uses: slackapi/slack-github-action@v1.25.0
|
||||
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
# with:
|
||||
# # Slack channel id, channel name, or user id to post message.
|
||||
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# # For posting a rich message using Block Kit
|
||||
# payload: |
|
||||
# {
|
||||
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
# "blocks": [
|
||||
# {
|
||||
# "type": "section",
|
||||
# "text": {
|
||||
# "type": "mrkdwn",
|
||||
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# env:
|
||||
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
latest-cuda:
|
||||
name: "Build GPU"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# HACK(aliberts): to be removed for release
|
||||
# -----------------------------------------
|
||||
- name: Checkout gym-aloha
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-aloha
|
||||
path: envs/gym-aloha
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-xarm
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-xarm
|
||||
path: envs/gym-xarm
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-pusht
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-pusht
|
||||
path: envs/gym-pusht
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Change envs dependencies as local path
|
||||
run: python .github/scripts/dep_build.py
|
||||
# -----------------------------------------
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-gpu/Dockerfile
|
||||
push: true
|
||||
tags: huggingface/lerobot-gpu
|
||||
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||
|
||||
# - name: Post to a Slack channel
|
||||
# id: slack
|
||||
# #uses: slackapi/slack-github-action@v1.25.0
|
||||
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
# with:
|
||||
# # Slack channel id, channel name, or user id to post message.
|
||||
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# # For posting a rich message using Block Kit
|
||||
# payload: |
|
||||
# {
|
||||
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
# "blocks": [
|
||||
# {
|
||||
# "type": "section",
|
||||
# "text": {
|
||||
# "type": "mrkdwn",
|
||||
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# env:
|
||||
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
79
.github/workflows/nightly-tests.yml
vendored
79
.github/workflows/nightly-tests.yml
vendored
@@ -1,79 +0,0 @@
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||
name: Nightly
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||
|
||||
jobs:
|
||||
run_all_tests_cpu:
|
||||
name: "Test CPU"
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: huggingface/lerobot-cpu:latest
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Tests
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: pytest -v --cov=./lerobot --disable-warnings tests
|
||||
|
||||
- name: Tests end-to-end
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
run: make test-end-to-end
|
||||
|
||||
|
||||
run_all_tests_single_gpu:
|
||||
name: "Test GPU"
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu"
|
||||
container:
|
||||
image: huggingface/lerobot-gpu:latest
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Nvidia-smi
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Test
|
||||
run: pytest -v --cov=./lerobot --cov-report=xml --disable-warnings tests
|
||||
# TODO(aliberts): Link with HF Codecov account
|
||||
# - name: Upload coverage reports to Codecov with GitHub Action
|
||||
# uses: codecov/codecov-action@v4
|
||||
# with:
|
||||
# files: ./coverage.xml
|
||||
# verbose: true
|
||||
- name: Tests end-to-end
|
||||
run: make test-end-to-end
|
||||
|
||||
# - name: Generate Report
|
||||
# if: always()
|
||||
# run: |
|
||||
# pip install slack_sdk tabulate
|
||||
# python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
38
.github/workflows/style.yml
vendored
38
.github/workflows/style.yml
vendored
@@ -1,38 +0,0 @@
|
||||
name: Style
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
ruff_check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Get Ruff Version from pre-commit-config.yaml
|
||||
id: get-ruff-version
|
||||
run: |
|
||||
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||
|
||||
- name: Install Ruff
|
||||
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||
|
||||
- name: Run Ruff
|
||||
run: ruff check .
|
||||
109
.github/workflows/test-docker-build.yml
vendored
109
.github/workflows/test-docker-build.yml
vendored
@@ -1,109 +0,0 @@
|
||||
# Inspired by
|
||||
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||
name: Test Docker builds (PR)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
# Run only when DockerFile files are modified
|
||||
- "docker/**"
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
get_changed_files:
|
||||
name: "Get all modified Dockerfiles"
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
with:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
- name: Run step if only the files listed above change
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: set-matrix
|
||||
env:
|
||||
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||
run: |
|
||||
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build_modified_dockerfiles:
|
||||
name: "Build all modified Docker images"
|
||||
needs: get_changed_files
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
# HACK(aliberts): to be removed for release
|
||||
# -----------------------------------------
|
||||
- name: Checkout gym-aloha
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-aloha
|
||||
path: envs/gym-aloha
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-xarm
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-xarm
|
||||
path: envs/gym-xarm
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout gym-pusht
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
repository: huggingface/gym-pusht
|
||||
path: envs/gym-pusht
|
||||
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Change envs dependencies as local path
|
||||
run: python .github/scripts/dep_build.py
|
||||
# -----------------------------------------
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
file: ${{ matrix.docker-file }}
|
||||
context: .
|
||||
push: False
|
||||
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||
158
.github/workflows/test.yml
vendored
158
.github/workflows/test.yml
vendored
@@ -1,74 +1,118 @@
|
||||
name: Tests
|
||||
name: Test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
types: [opened, synchronize, reopened, labeled]
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "lerobot/**"
|
||||
- "tests/**"
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, macos-latest-large]
|
||||
test:
|
||||
if: |
|
||||
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
|
||||
${{ github.event_name == 'push' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DATA_DIR: tests/data
|
||||
MUJOCO_GL: egl
|
||||
POETRY_VERSION: 1.8.1
|
||||
steps:
|
||||
- name: Add SSH key for installing envs
|
||||
uses: webfactory/ssh-agent@v0.9.0
|
||||
with:
|
||||
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install EGL
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||
sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
elif [[ "${{ matrix.os }}" == 'macos-latest' || "${{ matrix.os }}" == 'macos-latest-large' ]]; then
|
||||
brew install mesa
|
||||
fi
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Set up Python 3.10
|
||||
#----------------------------------------------
|
||||
# check-out repo and set-up python
|
||||
#----------------------------------------------
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up python
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "poetry"
|
||||
|
||||
- name: Install poetry dependencies
|
||||
python-version: '3.10'
|
||||
#----------------------------------------------
|
||||
# install & configure poetry
|
||||
#----------------------------------------------
|
||||
- name: Load cached Poetry installation
|
||||
id: restore-poetry-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
- name: Install Poetry
|
||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
version: ${{ env.POETRY_VERSION }}
|
||||
virtualenvs-create: true
|
||||
installer-parallel: true
|
||||
- name: Save cached Poetry installation
|
||||
if: |
|
||||
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-poetry-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: ~/.local # the path depends on the OS
|
||||
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
|
||||
- name: Configure Poetry
|
||||
run: poetry config virtualenvs.in-project true
|
||||
#----------------------------------------------
|
||||
# install dependencies
|
||||
#----------------------------------------------
|
||||
- name: Load cached venv
|
||||
id: restore-dependencies-cache
|
||||
uses: actions/cache/restore@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
- name: Install dependencies
|
||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
poetry install --all-extras
|
||||
|
||||
- name: Test with pytest
|
||||
poetry install --no-interaction --no-root
|
||||
git clone https://github.com/real-stanford/diffusion_policy
|
||||
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
|
||||
- name: Save cached venv
|
||||
if: |
|
||||
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
|
||||
github.ref_name == 'main'
|
||||
id: save-dependencies-cache
|
||||
uses: actions/cache/save@v3
|
||||
with:
|
||||
path: .venv
|
||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
||||
#----------------------------------------------
|
||||
# install project
|
||||
#----------------------------------------------
|
||||
- name: Install project
|
||||
run: poetry install --no-interaction
|
||||
#----------------------------------------------
|
||||
# run tests
|
||||
#----------------------------------------------
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest tests -v --cov=./lerobot --durations=0 \
|
||||
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
- name: Test end-to-end
|
||||
source .venv/bin/activate
|
||||
pytest tests
|
||||
- name: Test train pusht end-to-end
|
||||
run: |
|
||||
make test-end-to-end \
|
||||
&& rm -rf outputs
|
||||
source .venv/bin/activate
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.job.name=pusht \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=1 \
|
||||
online_steps=0 \
|
||||
device=cpu
|
||||
# TODO(rcadene, aliberts): Add end-to-end test of eval checkpoint post training
|
||||
# - name: Test eval pusht end-to-end
|
||||
# run: |
|
||||
# source .venv/bin/activate
|
||||
# python lerobot/scripts/eval.py
|
||||
# hydra.job.name=pusht \
|
||||
# env=pusht \
|
||||
# wandb.enable=False \
|
||||
# eval_episodes=1 \
|
||||
# device=cpu
|
||||
#----------------------------------------------
|
||||
# cleanup
|
||||
#----------------------------------------------
|
||||
- name: Cleanup
|
||||
run: rm -rf diffusion_policy data
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
# Custom
|
||||
diffusion_policy
|
||||
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -11,9 +14,6 @@ rl
|
||||
nautilus/*.yaml
|
||||
*.key
|
||||
|
||||
# Slurm
|
||||
sbatch*.sh
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
@@ -54,7 +54,6 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
exclude: ^(tests/data)
|
||||
exclude: ^(data/|tests/|diffusion_policy/)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.6.0
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: debug-statements
|
||||
@@ -14,11 +14,11 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.15.2
|
||||
rev: v3.15.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.4.2
|
||||
rev: v0.2.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the overall
|
||||
community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or advances of
|
||||
any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email address,
|
||||
without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official email address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
[feedback@huggingface.co](mailto:feedback@huggingface.co).
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series of
|
||||
actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or permanent
|
||||
ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||
community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.1, available at
|
||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
||||
|
||||
Community Impact Guidelines were inspired by
|
||||
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
||||
[https://www.contributor-covenant.org/translations][translations].
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
||||
[Mozilla CoC]: https://github.com/mozilla/diversity
|
||||
[FAQ]: https://www.contributor-covenant.org/faq
|
||||
[translations]: https://www.contributor-covenant.org/translations
|
||||
270
CONTRIBUTING.md
270
CONTRIBUTING.md
@@ -1,270 +0,0 @@
|
||||
# How to contribute to 🤗 LeRobot?
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
is thus not the only way to help the community. Answering questions, helping
|
||||
others, reaching out and improving the documentations are immensely valuable to
|
||||
the community.
|
||||
|
||||
It also helps us if you spread the word: reference the library from blog posts
|
||||
on the awesome projects it made possible, shout out on Twitter when it has
|
||||
helped you, or simply ⭐️ the repo to say "thank you".
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our
|
||||
[code of conduct](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
## You can contribute in so many ways!
|
||||
|
||||
Some of the ways you can contribute to 🤗 LeRobot:
|
||||
* Fixing outstanding issues with the existing code.
|
||||
* Implementing new models, datasets or simulation environments.
|
||||
* Contributing to the examples or to the documentation.
|
||||
* Submitting issues related to bugs or desired new features.
|
||||
|
||||
Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](remi.cadene@huggingface.co).
|
||||
|
||||
If you are not sure how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/orgs/huggingface/projects/46)
|
||||
|
||||
## Submitting a new issue or feature request
|
||||
|
||||
Do your best to follow these guidelines when submitting an issue or a feature
|
||||
request. It will make it easier for us to come back to you quickly and with good
|
||||
feedback.
|
||||
|
||||
### Did you find a bug?
|
||||
|
||||
The 🤗 LeRobot library is robust and reliable thanks to the users who notify us of
|
||||
the problems they encounter. So thank you for reporting an issue.
|
||||
|
||||
First, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on Github under Issues).
|
||||
|
||||
Did not find it? :( So we can act quickly on it, please follow these steps:
|
||||
|
||||
* Include your **OS type and version**, the versions of **Python** and **PyTorch**.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
* The full traceback if an exception is raised.
|
||||
* Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
### Do you want a new feature?
|
||||
|
||||
A good feature request addresses the following points:
|
||||
|
||||
1. Motivation first:
|
||||
* Is it related to a problem/frustration with the library? If so, please explain
|
||||
why. Providing a code snippet that demonstrates the problem is best.
|
||||
* Is it related to something you would need for a project? We'd love to hear
|
||||
about it!
|
||||
* Is it something you worked on and think could benefit the community?
|
||||
Awesome! Tell us what problem it solved for you.
|
||||
2. Write a *paragraph* describing the feature.
|
||||
3. Provide a **code snippet** that demonstrates its future use.
|
||||
4. In case this is related to a paper, please attach a link.
|
||||
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you
|
||||
post it.
|
||||
|
||||
## Adding new policies, datasets or environments
|
||||
|
||||
Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
|
||||
environments ([aloha](https://github.com/huggingface/gym-aloha),
|
||||
[xarm](https://github.com/huggingface/gym-xarm),
|
||||
[pusht](https://github.com/huggingface/gym-pusht))
|
||||
and follow the same api design.
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
🤗 LeRobot. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/lerobot) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
git clone git@github.com:<your Github handle>/lerobot.git
|
||||
cd lerobot
|
||||
git remote add upstream https://github.com/huggingface/lerobot.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
🚨 **Do not** work on the `main` branch.
|
||||
|
||||
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||
|
||||
Set up a development environment with conda or miniconda:
|
||||
```bash
|
||||
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||
```
|
||||
|
||||
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||
```bash
|
||||
poetry install --sync --extras "dev test"
|
||||
```
|
||||
|
||||
You can also install the project with all its dependencies (including environments):
|
||||
```bash
|
||||
poetry install --sync --all-extras
|
||||
```
|
||||
|
||||
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||
|
||||
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||
|
||||
The equivalent of `pip install some-package`, would just be:
|
||||
```bash
|
||||
poetry add some-package
|
||||
```
|
||||
|
||||
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||
```bash
|
||||
poetry lock --no-update
|
||||
```
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
6. Follow our style.
|
||||
|
||||
`lerobot` relies on `ruff` to format its source code
|
||||
consistently. Set up [`pre-commit`](https://pre-commit.com/) to run these checks
|
||||
automatically as Git commit hooks.
|
||||
|
||||
Install `pre-commit` hooks:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
You can run these hooks whenever you need on staged files with:
|
||||
```bash
|
||||
pre-commit
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or preferably mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
<!-- 5. Add high-coverage tests. No quality testing = no merge.
|
||||
|
||||
See an example of a good PR here: https://github.com/huggingface/lerobot/pull/ -->
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/lerobot/tree/main/tests).
|
||||
|
||||
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
|
||||
|
||||
On Mac:
|
||||
```bash
|
||||
brew install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
On Ubuntu:
|
||||
```bash
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
DATA_DIR="tests/data" python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
507
LICENSE
507
LICENSE
@@ -1,507 +0,0 @@
|
||||
Copyright 2024 The Hugging Face team. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from Diffusion Policy, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Columbia Artificial Intelligence and Robotics Lab
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from FOWM, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Yunhai Feng
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from simxarm, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nicklas Hansen & Yanjie Ze
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
## Some of lerobot's code is derived from ALOHA, which is subject to the following copyright notice:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Tony Z. Zhao
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
## Some of lerobot's code is derived from DETR, which is subject to the following copyright notice:
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2020 - present, Facebook, Inc
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
95
Makefile
95
Makefile
@@ -1,95 +0,0 @@
|
||||
.PHONY: tests
|
||||
|
||||
PYTHON_PATH := $(shell which python)
|
||||
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
|
||||
|
||||
build-cpu:
|
||||
docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
|
||||
|
||||
build-gpu:
|
||||
docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
|
||||
|
||||
test-end-to-end:
|
||||
${MAKE} test-act-ete-train
|
||||
${MAKE} test-act-ete-eval
|
||||
${MAKE} test-diffusion-ete-train
|
||||
${MAKE} test-diffusion-ete-eval
|
||||
${MAKE} test-tdmpc-ete-train
|
||||
${MAKE} test-tdmpc-ete-eval
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=act \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
eval_episodes=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
policy.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/act/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||
|
||||
test-diffusion-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=diffusion \
|
||||
env=pusht \
|
||||
wandb.enable=False \
|
||||
offline_steps=2 \
|
||||
online_steps=0 \
|
||||
eval_episodes=1 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
policy.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/diffusion/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
env=xarm \
|
||||
wandb.enable=False \
|
||||
offline_steps=1 \
|
||||
online_steps=2 \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
save_model=true \
|
||||
save_freq=2 \
|
||||
policy.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
||||
eval_episodes=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||
309
README.md
309
README.md
@@ -1,282 +1,74 @@
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
|
||||
</picture>
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml?query=branch%3Amain)
|
||||
[](https://codecov.io/gh/huggingface/lerobot)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
</div>
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. In the coming weeks, the plan is to add more and more support for real-world robotics on the most affordable and capable robots out there.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on this HuggingFace community page: [huggingface.co/lerobot](https://huggingface.co/lerobot)
|
||||
|
||||
#### Examples of pretrained models and environments
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
<td align="center">TDMPC policy on SimXArm env</td>
|
||||
<td align="center">Diffusion policy on PushT env</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- ACT policy and ALOHA environment are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha/)
|
||||
- Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
|
||||
- TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)
|
||||
- Abstractions and utilities for Reinforcement Learning come from [TorchRL](https://github.com/pytorch/rl)
|
||||
# LeRobot
|
||||
|
||||
## Installation
|
||||
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git && cd lerobot
|
||||
Create a virtual environment with python 3.10, e.g. using `conda`:
|
||||
```
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
[Install `poetry`](https://python-poetry.org/docs/#installation) (if you don't have it already)
|
||||
```
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
```
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install .
|
||||
Install dependencies
|
||||
```
|
||||
poetry install
|
||||
```
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install ".[aloha, pusht]"
|
||||
If you encounter a disk space error, try to change your tmp dir to a location where you have enough disk space, e.g.
|
||||
```
|
||||
mkdir ~/tmp
|
||||
export TMPDIR='~/tmp'
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||
```bash
|
||||
wandb login
|
||||
Install `diffusion_policy` #HACK
|
||||
```
|
||||
# from this directory
|
||||
git clone https://github.com/real-stanford/diffusion_policy
|
||||
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
|
||||
```
|
||||
|
||||
## Walkthrough
|
||||
## Usage
|
||||
|
||||
|
||||
### Train
|
||||
|
||||
```
|
||||
.
|
||||
├── lerobot
|
||||
| ├── configs # contains hydra yaml files with all options that you can override in the command line
|
||||
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
|
||||
| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
|
||||
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
|
||||
| ├── common # contains classes and utilities
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
| | └── policies # various policies: act, diffusion, tdmpc
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| └── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
├── .github
|
||||
| └── workflows
|
||||
| └── test.yml # defines install settings for continuous integration and specifies end-to-end tests
|
||||
└── tests # contains pytest utilities for continuous integration
|
||||
|
||||
```
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
env=pusht \
|
||||
hydra.run.dir=outputs/visualize_dataset/example
|
||||
# >>> ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
```
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
|
||||
|
||||
Or you can achieve the same result by executing our script from the command line:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--hub-id lerobot/diffusion_policy_pusht_image \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_hub
|
||||
```
|
||||
|
||||
After training your own policy, you can also re-evaluate the checkpoints with:
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10 \
|
||||
hydra.run.dir=outputs/eval/example_dir
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
|
||||
|
||||
In general, you can use our training script to easily train any policy on any environment:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
env=aloha \
|
||||
task=sim_insertion \
|
||||
repo_id=lerobot/aloha_sim_insertion_scripted \
|
||||
policy=act \
|
||||
hydra.run.dir=outputs/train/aloha_act
|
||||
hydra.job.name=pusht \
|
||||
env=pusht
|
||||
```
|
||||
|
||||
After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
|
||||
|
||||
## Contribute
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
### Add a new dataset
|
||||
|
||||
```python
|
||||
# TODO(rcadene, AdilZouitine): rewrite this section
|
||||
```
|
||||
|
||||
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then you can upload it to the hub with:
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
You will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
For instance, for [lerobot/pusht](https://huggingface.co/datasets/lerobot/pusht), we used:
|
||||
```bash
|
||||
HF_USER=lerobot
|
||||
DATASET=pusht
|
||||
```
|
||||
|
||||
If you want to improve an existing dataset, you can download it locally with:
|
||||
```bash
|
||||
mkdir -p data/$DATASET
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download ${HF_USER}/$DATASET \
|
||||
--repo-type dataset \
|
||||
--local-dir data/$DATASET \
|
||||
--local-dir-use-symlinks=False \
|
||||
--revision v1.0
|
||||
```
|
||||
|
||||
Iterate on your code and dataset with:
|
||||
```bash
|
||||
DATA_DIR=data python train.py
|
||||
```
|
||||
|
||||
Upload a new version (v2.0 or v1.1 if the changes are respectively more or less significant):
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATASET \
|
||||
--repo-type dataset \
|
||||
--revision v1.1 \
|
||||
--delete "*"
|
||||
```
|
||||
|
||||
Then you will need to set the corresponding version as a default argument in your dataset class:
|
||||
```python
|
||||
version: str | None = "v1.1",
|
||||
```
|
||||
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
|
||||
|
||||
|
||||
Finally, you might want to mock the dataset if you need to update the unit tests as well:
|
||||
```bash
|
||||
python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir tests/data/$DATASET
|
||||
```
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
```python
|
||||
# TODO(rcadene, alexander-soare): rewrite this section
|
||||
```
|
||||
|
||||
Once you have trained a policy you may upload it to the HuggingFace hub.
|
||||
|
||||
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
|
||||
|
||||
Secondly, assuming you have trained a policy, you need:
|
||||
|
||||
- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
|
||||
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
|
||||
|
||||
To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):
|
||||
### Visualize offline buffer
|
||||
|
||||
```
|
||||
to_upload
|
||||
├── config.yaml
|
||||
└── model.pt
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
|
||||
env=pusht
|
||||
```
|
||||
|
||||
With the folder prepared, run the following with a desired revision ID.
|
||||
### Visualize online buffer / Eval
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload --revision $REVISION_ID
|
||||
```
|
||||
python lerobot/scripts/eval.py \
|
||||
hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
|
||||
env=pusht
|
||||
```
|
||||
|
||||
If you want this to be the default revision also run the following (don't worry, it won't upload the files again; it will just adjust the file pointers):
|
||||
|
||||
```bash
|
||||
huggingface-cli upload $HUB_ID to_upload
|
||||
```
|
||||
## TODO
|
||||
|
||||
See `eval.py` for an example of how a user may use your policy.
|
||||
If you don't know how to contribute or want to know the next features we working on, look on this project page: [LeRobot TODO](https://github.com/users/Cadene/projects/1)
|
||||
|
||||
Ask [Remi Cadene](re.cadene@gmail.com) for access if needed.
|
||||
|
||||
|
||||
### Improve your code with profiling
|
||||
## Profile
|
||||
|
||||
An example of a code snippet to profile the evaluation of a policy:
|
||||
**Example**
|
||||
```python
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
@@ -295,12 +87,25 @@ with profile(
|
||||
with record_function("eval_policy"):
|
||||
for i in range(num_episodes):
|
||||
prof.step()
|
||||
# insert code to profile, potentially whole body of eval_policy function
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/eval.py \
|
||||
--config outputs/pusht/.hydra/config.yaml \
|
||||
pretrained_model_path=outputs/pusht/model.pt \
|
||||
pretrained_model_path=/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt \
|
||||
eval_episodes=7
|
||||
```
|
||||
|
||||
## Contribute
|
||||
|
||||
**Style**
|
||||
```
|
||||
# install if needed
|
||||
pre-commit install
|
||||
# apply style and linter checks before git commit
|
||||
pre-commit run -a
|
||||
```
|
||||
|
||||
**Tests**
|
||||
```
|
||||
pytest -sx tests
|
||||
```
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
FROM python:${PYTHON_VERSION}-slim
|
||||
ARG PYTHON_VERSION
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
|
||||
# Execute in bash shell rather than python
|
||||
CMD ["/bin/bash"]
|
||||
@@ -1,27 +0,0 @@
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
|
||||
# Install LeRobot
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
@@ -1,91 +0,0 @@
|
||||
"""
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
- Loading a dataset and accessing its properties.
|
||||
- Filtering data by episode number.
|
||||
- Converting tensor data for visualization.
|
||||
- Saving video files from dataset frames.
|
||||
- Using advanced dataset features like timestamp-based frame selection.
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
|
||||
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
|
||||
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
|
||||
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
# You can easily load a dataset from a Hugging Face repositery
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
# and provides additional utilities for robotics and compatibility with pytorch
|
||||
print(f"number of samples/frames: {dataset.num_samples=}")
|
||||
print(f"number of episodes: {dataset.num_episodes=}")
|
||||
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset.
|
||||
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [sample["observation.image"] for sample in dataset]
|
||||
|
||||
# but frames are now float32 range [0,1] channel first (c,h,w) to follow pytorch convention,
|
||||
# to view them, we convert to uint8 range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c)
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
||||
# and finally save them to a mp4 video
|
||||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
|
||||
|
||||
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
|
||||
# using timestamps differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
"observation.image": [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}") # (64,c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
|
||||
# because they are just PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
for batch in dataloader:
|
||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
||||
print(f"{batch['action'].shape=}") # (32,64,c)
|
||||
break
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
# Get a pretrained policy from the hub.
|
||||
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
|
||||
hub_id = "lerobot/diffusion_policy_pusht_image"
|
||||
folder = Path(snapshot_download(hub_id))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
# folder = Path("outputs/train/example_pusht_diffusion")
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
f"policy.pretrained_model_path={weights_path}",
|
||||
"eval_episodes=10",
|
||||
"rollout_batch_size=10",
|
||||
"device=cuda",
|
||||
]
|
||||
|
||||
# Create a Hydra config.
|
||||
cfg = init_hydra_config(config_path, overrides)
|
||||
|
||||
# Evaluate the policy and save the outputs including metrics and videos.
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
)
|
||||
@@ -1,67 +0,0 @@
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
# Number of offline training steps (we'll only do offline training for this example.
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
device = torch.device("cuda")
|
||||
log_freq = 250
|
||||
|
||||
# Set up the dataset.
|
||||
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
|
||||
dataset = make_dataset(hydra_cfg)
|
||||
|
||||
# Set up the the policy.
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# Create dataloader for offline training.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device != torch.device("cpu"),
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
info = policy.update(batch)
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy and configuration for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
||||
@@ -1,92 +0,0 @@
|
||||
"""
|
||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import lerobot
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_policies)
|
||||
print(lerobot.available_policies_per_env)
|
||||
```
|
||||
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
"""
|
||||
|
||||
import itertools
|
||||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"AlohaInsertion-v0",
|
||||
"AlohaTransferCube-v0",
|
||||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
available_datasets_per_env = {
|
||||
"aloha": [
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["lerobot/pusht"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/xarm_lift_medium_replay",
|
||||
"lerobot/xarm_push_medium",
|
||||
"lerobot/xarm_push_medium_replay",
|
||||
],
|
||||
}
|
||||
|
||||
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
|
||||
|
||||
available_datasets = list(
|
||||
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
|
||||
)
|
||||
|
||||
# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
|
||||
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
|
||||
|
||||
available_datasets = list(
|
||||
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
|
||||
)
|
||||
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
]
|
||||
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion"],
|
||||
"xarm": ["tdmpc"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
env_dataset_pairs = [
|
||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
||||
]
|
||||
env_dataset_policy_triplets = [
|
||||
(env, dataset, policy)
|
||||
for env, datasets in available_datasets_per_env.items()
|
||||
for dataset in datasets
|
||||
for policy in available_policies_per_env[env]
|
||||
]
|
||||
|
||||
@@ -1,8 +1 @@
|
||||
"""To enable `lerobot.__version__`"""
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
try:
|
||||
__version__ = version("lerobot")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "unknown"
|
||||
__version__ = "0.0.0"
|
||||
|
||||
0
lerobot/common/__init__.py
Normal file
0
lerobot/common/__init__.py
Normal file
0
lerobot/common/datasets/__init__.py
Normal file
0
lerobot/common/datasets/__init__.py
Normal file
158
lerobot/common/datasets/abstract.py
Normal file
158
lerobot/common/datasets/abstract.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import abc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.datasets.utils import _get_root_dir
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.shuffle = shuffle
|
||||
self.root = _get_root_dir(self.dataset_id) if root is None else root
|
||||
self.root = Path(self.root)
|
||||
self.data_dir = self.root / self.dataset_id
|
||||
|
||||
storage = self._download_or_load_storage()
|
||||
|
||||
super().__init__(
|
||||
storage=storage,
|
||||
sampler=sampler,
|
||||
writer=ImmutableDatasetWriter() if writer is None else writer,
|
||||
collate_fn=_collate_id if collate_fn is None else collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
batch_size=batch_size,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
return {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("observation", "image"): "b c h w -> 1 c 1 1",
|
||||
("action"): "b c -> 1 c",
|
||||
}
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image")]
|
||||
|
||||
@property
|
||||
def num_cameras(self) -> int:
|
||||
return len(self.image_keys)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self._storage._storage["episode"].unique())
|
||||
|
||||
def set_transform(self, transform):
|
||||
self.transform = transform
|
||||
|
||||
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
|
||||
stats_path = self.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
stats = torch.load(stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
stats = self._compute_stats(num_batch, batch_size)
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
@abc.abstractmethod
|
||||
def _download_and_preproc(self) -> torch.StorageBase:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _download_or_load_storage(self):
|
||||
if not self._is_downloaded():
|
||||
storage = self._download_and_preproc()
|
||||
else:
|
||||
storage = TensorStorage(TensorDict.load_memmap(self.data_dir))
|
||||
return storage
|
||||
|
||||
def _is_downloaded(self) -> bool:
|
||||
return self.data_dir.is_dir()
|
||||
|
||||
def _compute_stats(self, num_batch=100, batch_size=32):
|
||||
rb = TensorDictReplayBuffer(
|
||||
storage=self._storage,
|
||||
batch_size=batch_size,
|
||||
prefetch=True,
|
||||
)
|
||||
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
|
||||
# compute mean, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
if key not in mean:
|
||||
# first batch initialize mean, min, max
|
||||
mean[key] = einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = einops.reduce(batch[key], pattern, "max")
|
||||
min[key] = einops.reduce(batch[key], pattern, "min")
|
||||
else:
|
||||
mean[key] += einops.reduce(batch[key], pattern, "mean")
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
batch = rb.sample()
|
||||
|
||||
for key in self.stats_patterns:
|
||||
mean[key] /= num_batch
|
||||
|
||||
# compute std, min, max
|
||||
for _ in tqdm.tqdm(range(num_batch)):
|
||||
batch = rb.sample()
|
||||
for key, pattern in self.stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
if key not in std:
|
||||
# first batch initialize std
|
||||
std[key] = (batch_mean - mean[key]) ** 2
|
||||
else:
|
||||
std[key] += (batch_mean - mean[key]) ** 2
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
for key in self.stats_patterns:
|
||||
std[key] = torch.sqrt(std[key] / num_batch)
|
||||
|
||||
stats = TensorDict({}, batch_size=[])
|
||||
for key in self.stats_patterns:
|
||||
stats[(*key, "mean")] = mean[key]
|
||||
stats[(*key, "std")] = std[key]
|
||||
stats[(*key, "max")] = max[key]
|
||||
stats[(*key, "min")] = min[key]
|
||||
|
||||
if key[0] == "observation":
|
||||
# use same stats for the next observations
|
||||
stats[("next", *key)] = stats[key]
|
||||
return stats
|
||||
182
lerobot/common/datasets/aloha.py
Normal file
182
lerobot/common/datasets/aloha.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
EP48_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
EP49_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
|
||||
NUM_EPISODES = {
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
EPISODE_LEN = {
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
CAMERAS = {
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in DATASET_IDS
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert dataset_id in DATASET_IDS
|
||||
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
("observation", "state"): "b c -> 1 c",
|
||||
("action"): "b c -> 1 c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
|
||||
return d
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_num_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_num_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_num_frames=}")
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
idxtd = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
)
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
@@ -3,42 +3,112 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.envs.transforms import NormalizeTransform
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
def make_offline_buffer(
|
||||
cfg, overwrite_sampler=None, normalize=True, overwrite_batch_size=None, overwrite_prefetch=None
|
||||
):
|
||||
if cfg.env.name not in cfg.dataset.repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
|
||||
)
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
pin_memory = False
|
||||
prefetch = None
|
||||
else:
|
||||
assert cfg.online_steps == 0
|
||||
num_slices = cfg.policy.batch_size
|
||||
batch_size = cfg.policy.horizon * num_slices
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
split=split,
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
|
||||
if cfg.offline_prioritized_sampler:
|
||||
logging.info("use prioritized sampler for offline dataset")
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
logging.info("use simple sampler for offline dataset")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
|
||||
clsfunc = SimxarmExperienceReplay
|
||||
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
|
||||
clsfunc = PushtExperienceReplay
|
||||
dataset_id = "pusht"
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaExperienceReplay
|
||||
|
||||
clsfunc = AlohaExperienceReplay
|
||||
dataset_id = f"aloha_{cfg.env.task}"
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats()
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
return dataset
|
||||
if cfg.policy == "tdmpc":
|
||||
for key in offline_buffer.image_keys:
|
||||
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
|
||||
in_keys.append(key)
|
||||
# since we use next observations in tdmpc
|
||||
in_keys.append(("next", *key))
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
offline_buffer.set_transform(transform)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
sampler.extend(index)
|
||||
|
||||
return offline_buffer
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
@@ -1,619 +0,0 @@
|
||||
"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/)
|
||||
|
||||
Copied from the original Diffusion Policy repository and used in our `download_and_upload_dataset.py` script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
import numcodecs
|
||||
import numpy as np
|
||||
import zarr
|
||||
|
||||
|
||||
def check_chunks_compatible(chunks: tuple, shape: tuple):
|
||||
assert len(shape) == len(chunks)
|
||||
for c in chunks:
|
||||
assert isinstance(c, numbers.Integral)
|
||||
assert c > 0
|
||||
|
||||
|
||||
def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"):
|
||||
old_arr = group[name]
|
||||
if chunks is None:
|
||||
chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks
|
||||
check_chunks_compatible(chunks, old_arr.shape)
|
||||
|
||||
if compressor is None:
|
||||
compressor = old_arr.compressor
|
||||
|
||||
if (chunks == old_arr.chunks) and (compressor == old_arr.compressor):
|
||||
# no change
|
||||
return old_arr
|
||||
|
||||
# rechunk recompress
|
||||
group.move(name, tmp_key)
|
||||
old_arr = group[tmp_key]
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=old_arr,
|
||||
dest=group,
|
||||
name=name,
|
||||
chunks=chunks,
|
||||
compressor=compressor,
|
||||
)
|
||||
del group[tmp_key]
|
||||
arr = group[name]
|
||||
return arr
|
||||
|
||||
|
||||
def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=None):
|
||||
"""
|
||||
Common shapes
|
||||
T,D
|
||||
T,N,D
|
||||
T,H,W,C
|
||||
T,N,H,W,C
|
||||
"""
|
||||
itemsize = np.dtype(dtype).itemsize
|
||||
# reversed
|
||||
rshape = list(shape[::-1])
|
||||
if max_chunk_length is not None:
|
||||
rshape[-1] = int(max_chunk_length)
|
||||
split_idx = len(shape) - 1
|
||||
for i in range(len(shape) - 1):
|
||||
this_chunk_bytes = itemsize * np.prod(rshape[:i])
|
||||
next_chunk_bytes = itemsize * np.prod(rshape[: i + 1])
|
||||
if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes:
|
||||
split_idx = i
|
||||
|
||||
rchunks = rshape[:split_idx]
|
||||
item_chunk_bytes = itemsize * np.prod(rshape[:split_idx])
|
||||
this_max_chunk_length = rshape[split_idx]
|
||||
next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes))
|
||||
rchunks.append(next_chunk_length)
|
||||
len_diff = len(shape) - len(rchunks)
|
||||
rchunks.extend([1] * len_diff)
|
||||
chunks = tuple(rchunks[::-1])
|
||||
# print(np.prod(chunks) * itemsize / target_chunk_bytes)
|
||||
return chunks
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
"""
|
||||
Zarr-based temporal datastructure.
|
||||
Assumes first dimension to be time. Only chunk in time dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, root: zarr.Group | dict[str, dict]):
|
||||
"""
|
||||
Dummy constructor. Use copy_from* and create_from* class methods instead.
|
||||
"""
|
||||
assert "data" in root
|
||||
assert "meta" in root
|
||||
assert "episode_ends" in root["meta"]
|
||||
for value in root["data"].values():
|
||||
assert value.shape[0] == root["meta"]["episode_ends"][-1]
|
||||
self.root = root
|
||||
|
||||
# ============= create constructors ===============
|
||||
@classmethod
|
||||
def create_empty_zarr(cls, storage=None, root=None):
|
||||
if root is None:
|
||||
if storage is None:
|
||||
storage = zarr.MemoryStore()
|
||||
root = zarr.group(store=storage)
|
||||
root.require_group("data", overwrite=False)
|
||||
meta = root.require_group("meta", overwrite=False)
|
||||
if "episode_ends" not in meta:
|
||||
meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False)
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
def create_empty_numpy(cls):
|
||||
root = {"data": {}, "meta": {"episode_ends": np.zeros((0,), dtype=np.int64)}}
|
||||
return cls(root=root)
|
||||
|
||||
@classmethod
|
||||
def create_from_group(cls, group, **kwargs):
|
||||
if "data" not in group:
|
||||
# create from stratch
|
||||
buffer = cls.create_empty_zarr(root=group, **kwargs)
|
||||
else:
|
||||
# already exist
|
||||
buffer = cls(root=group, **kwargs)
|
||||
return buffer
|
||||
|
||||
@classmethod
|
||||
def create_from_path(cls, zarr_path, mode="r", **kwargs):
|
||||
"""
|
||||
Open a on-disk zarr directly (for dataset larger than memory).
|
||||
Slower.
|
||||
"""
|
||||
group = zarr.open(os.path.expanduser(zarr_path), mode)
|
||||
return cls.create_from_group(group, **kwargs)
|
||||
|
||||
# ============= copy constructors ===============
|
||||
@classmethod
|
||||
def copy_from_store(
|
||||
cls,
|
||||
src_store,
|
||||
store=None,
|
||||
keys=None,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: dict | str | numcodecs.abc.Codec | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load to memory.
|
||||
"""
|
||||
src_root = zarr.group(src_store)
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
root = None
|
||||
if store is None:
|
||||
# numpy backend
|
||||
meta = {}
|
||||
for key, value in src_root["meta"].items():
|
||||
if len(value.shape) == 0:
|
||||
meta[key] = np.array(value)
|
||||
else:
|
||||
meta[key] = value[:]
|
||||
|
||||
if keys is None:
|
||||
keys = src_root["data"].keys()
|
||||
data = {}
|
||||
for key in keys:
|
||||
arr = src_root["data"][key]
|
||||
data[key] = arr[:]
|
||||
|
||||
root = {"meta": meta, "data": data}
|
||||
else:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
if keys is None:
|
||||
keys = src_root["data"].keys()
|
||||
for key in keys:
|
||||
value = src_root["data"][key]
|
||||
cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
dest_path=this_path,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
chunks=cks,
|
||||
compressor=cpr,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
buffer = cls(root=root)
|
||||
return buffer
|
||||
|
||||
@classmethod
|
||||
def copy_from_path(
|
||||
cls,
|
||||
zarr_path,
|
||||
backend=None,
|
||||
store=None,
|
||||
keys=None,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: dict | str | numcodecs.abc.Codec | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Copy a on-disk zarr to in-memory compressed.
|
||||
Recommended
|
||||
"""
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
if backend == "numpy":
|
||||
print("backend argument is deprecated!")
|
||||
store = None
|
||||
group = zarr.open(os.path.expanduser(zarr_path), "r")
|
||||
return cls.copy_from_store(
|
||||
src_store=group.store,
|
||||
store=store,
|
||||
keys=keys,
|
||||
chunks=chunks,
|
||||
compressors=compressors,
|
||||
if_exists=if_exists,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ============= save methods ===============
|
||||
def save_to_store(
|
||||
self,
|
||||
store,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
root = zarr.group(store)
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
if self.backend == "zarr":
|
||||
# recompression free copy
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
dest_path="/meta",
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
meta_group = root.create_group("meta", overwrite=True)
|
||||
# save meta, no chunking
|
||||
for key, value in self.root["meta"].items():
|
||||
_ = meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape)
|
||||
|
||||
# save data, chunk
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
for key, value in self.root["data"].items():
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
if isinstance(value, zarr.Array):
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
dest_path=this_path,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
chunks=cks,
|
||||
compressor=cpr,
|
||||
if_exists=if_exists,
|
||||
)
|
||||
else:
|
||||
# numpy
|
||||
_ = data_group.array(name=key, data=value, chunks=cks, compressor=cpr)
|
||||
return store
|
||||
|
||||
def save_to_path(
|
||||
self,
|
||||
zarr_path,
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
if_exists="replace",
|
||||
**kwargs,
|
||||
):
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
store = zarr.DirectoryStore(os.path.expanduser(zarr_path))
|
||||
return self.save_to_store(
|
||||
store, chunks=chunks, compressors=compressors, if_exists=if_exists, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_compressor(compressor="default"):
|
||||
if compressor == "default":
|
||||
compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE)
|
||||
elif compressor == "disk":
|
||||
compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE)
|
||||
return compressor
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array):
|
||||
# allows compressor to be explicitly set to None
|
||||
cpr = "nil"
|
||||
if isinstance(compressors, dict):
|
||||
if key in compressors:
|
||||
cpr = cls.resolve_compressor(compressors[key])
|
||||
elif isinstance(array, zarr.Array):
|
||||
cpr = array.compressor
|
||||
else:
|
||||
cpr = cls.resolve_compressor(compressors)
|
||||
# backup default
|
||||
if cpr == "nil":
|
||||
cpr = cls.resolve_compressor("default")
|
||||
return cpr
|
||||
|
||||
@classmethod
|
||||
def _resolve_array_chunks(cls, chunks: dict | tuple, key, array):
|
||||
cks = None
|
||||
if isinstance(chunks, dict):
|
||||
if key in chunks:
|
||||
cks = chunks[key]
|
||||
elif isinstance(array, zarr.Array):
|
||||
cks = array.chunks
|
||||
elif isinstance(chunks, tuple):
|
||||
cks = chunks
|
||||
else:
|
||||
raise TypeError(f"Unsupported chunks type {type(chunks)}")
|
||||
# backup default
|
||||
if cks is None:
|
||||
cks = get_optimal_chunks(shape=array.shape, dtype=array.dtype)
|
||||
# check
|
||||
check_chunks_compatible(chunks=cks, shape=array.shape)
|
||||
return cks
|
||||
|
||||
# ============= properties =================
|
||||
@cached_property
|
||||
def data(self):
|
||||
return self.root["data"]
|
||||
|
||||
@cached_property
|
||||
def meta(self):
|
||||
return self.root["meta"]
|
||||
|
||||
def update_meta(self, data):
|
||||
# sanitize data
|
||||
np_data = {}
|
||||
for key, value in data.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
np_data[key] = value
|
||||
else:
|
||||
arr = np.array(value)
|
||||
if arr.dtype == object:
|
||||
raise TypeError(f"Invalid value type {type(value)}")
|
||||
np_data[key] = arr
|
||||
|
||||
meta_group = self.meta
|
||||
if self.backend == "zarr":
|
||||
for key, value in np_data.items():
|
||||
_ = meta_group.array(
|
||||
name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True
|
||||
)
|
||||
else:
|
||||
meta_group.update(np_data)
|
||||
|
||||
return meta_group
|
||||
|
||||
@property
|
||||
def episode_ends(self):
|
||||
return self.meta["episode_ends"]
|
||||
|
||||
def get_episode_idxs(self):
|
||||
import numba
|
||||
|
||||
numba.jit(nopython=True)
|
||||
|
||||
def _get_episode_idxs(episode_ends):
|
||||
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||
for i in range(len(episode_ends)):
|
||||
start = 0
|
||||
if i > 0:
|
||||
start = episode_ends[i - 1]
|
||||
end = episode_ends[i]
|
||||
for idx in range(start, end):
|
||||
result[idx] = i
|
||||
return result
|
||||
|
||||
return _get_episode_idxs(self.episode_ends)
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
backend = "numpy"
|
||||
if isinstance(self.root, zarr.Group):
|
||||
backend = "zarr"
|
||||
return backend
|
||||
|
||||
# =========== dict-like API ==============
|
||||
def __repr__(self) -> str:
|
||||
if self.backend == "zarr":
|
||||
return str(self.root.tree())
|
||||
else:
|
||||
return super().__repr__()
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def values(self):
|
||||
return self.data.values()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.data
|
||||
|
||||
# =========== our API ==============
|
||||
@property
|
||||
def n_steps(self):
|
||||
if len(self.episode_ends) == 0:
|
||||
return 0
|
||||
return self.episode_ends[-1]
|
||||
|
||||
@property
|
||||
def n_episodes(self):
|
||||
return len(self.episode_ends)
|
||||
|
||||
@property
|
||||
def chunk_size(self):
|
||||
if self.backend == "zarr":
|
||||
return next(iter(self.data.arrays()))[-1].chunks[0]
|
||||
return None
|
||||
|
||||
@property
|
||||
def episode_lengths(self):
|
||||
ends = self.episode_ends[:]
|
||||
ends = np.insert(ends, 0, 0)
|
||||
lengths = np.diff(ends)
|
||||
return lengths
|
||||
|
||||
def add_episode(
|
||||
self,
|
||||
data: dict[str, np.ndarray],
|
||||
chunks: dict[str, tuple] | None = None,
|
||||
compressors: str | numcodecs.abc.Codec | dict | None = None,
|
||||
):
|
||||
if chunks is None:
|
||||
chunks = {}
|
||||
if compressors is None:
|
||||
compressors = {}
|
||||
assert len(data) > 0
|
||||
is_zarr = self.backend == "zarr"
|
||||
|
||||
curr_len = self.n_steps
|
||||
episode_length = None
|
||||
for value in data.values():
|
||||
assert len(value.shape) >= 1
|
||||
if episode_length is None:
|
||||
episode_length = len(value)
|
||||
else:
|
||||
assert episode_length == len(value)
|
||||
new_len = curr_len + episode_length
|
||||
|
||||
for key, value in data.items():
|
||||
new_shape = (new_len,) + value.shape[1:]
|
||||
# create array
|
||||
if key not in self.data:
|
||||
if is_zarr:
|
||||
cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value)
|
||||
cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value)
|
||||
arr = self.data.zeros(
|
||||
name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr
|
||||
)
|
||||
else:
|
||||
# copy data to prevent modify
|
||||
arr = np.zeros(shape=new_shape, dtype=value.dtype)
|
||||
self.data[key] = arr
|
||||
else:
|
||||
arr = self.data[key]
|
||||
assert value.shape[1:] == arr.shape[1:]
|
||||
# same method for both zarr and numpy
|
||||
if is_zarr:
|
||||
arr.resize(new_shape)
|
||||
else:
|
||||
arr.resize(new_shape, refcheck=False)
|
||||
# copy data
|
||||
arr[-value.shape[0] :] = value
|
||||
|
||||
# append to episode ends
|
||||
episode_ends = self.episode_ends
|
||||
if is_zarr:
|
||||
episode_ends.resize(episode_ends.shape[0] + 1)
|
||||
else:
|
||||
episode_ends.resize(episode_ends.shape[0] + 1, refcheck=False)
|
||||
episode_ends[-1] = new_len
|
||||
|
||||
# rechunk
|
||||
if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]:
|
||||
rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5))
|
||||
|
||||
def drop_episode(self):
|
||||
is_zarr = self.backend == "zarr"
|
||||
episode_ends = self.episode_ends[:].copy()
|
||||
assert len(episode_ends) > 0
|
||||
start_idx = 0
|
||||
if len(episode_ends) > 1:
|
||||
start_idx = episode_ends[-2]
|
||||
for value in self.data.values():
|
||||
new_shape = (start_idx,) + value.shape[1:]
|
||||
if is_zarr:
|
||||
value.resize(new_shape)
|
||||
else:
|
||||
value.resize(new_shape, refcheck=False)
|
||||
if is_zarr:
|
||||
self.episode_ends.resize(len(episode_ends) - 1)
|
||||
else:
|
||||
self.episode_ends.resize(len(episode_ends) - 1, refcheck=False)
|
||||
|
||||
def pop_episode(self):
|
||||
assert self.n_episodes > 0
|
||||
episode = self.get_episode(self.n_episodes - 1, copy=True)
|
||||
self.drop_episode()
|
||||
return episode
|
||||
|
||||
def extend(self, data):
|
||||
self.add_episode(data)
|
||||
|
||||
def get_episode(self, idx, copy=False):
|
||||
idx = list(range(len(self.episode_ends)))[idx]
|
||||
start_idx = 0
|
||||
if idx > 0:
|
||||
start_idx = self.episode_ends[idx - 1]
|
||||
end_idx = self.episode_ends[idx]
|
||||
result = self.get_steps_slice(start_idx, end_idx, copy=copy)
|
||||
return result
|
||||
|
||||
def get_episode_slice(self, idx):
|
||||
start_idx = 0
|
||||
if idx > 0:
|
||||
start_idx = self.episode_ends[idx - 1]
|
||||
end_idx = self.episode_ends[idx]
|
||||
return slice(start_idx, end_idx)
|
||||
|
||||
def get_steps_slice(self, start, stop, step=None, copy=False):
|
||||
_slice = slice(start, stop, step)
|
||||
|
||||
result = {}
|
||||
for key, value in self.data.items():
|
||||
x = value[_slice]
|
||||
if copy and isinstance(value, np.ndarray):
|
||||
x = x.copy()
|
||||
result[key] = x
|
||||
return result
|
||||
|
||||
# =========== chunking =============
|
||||
def get_chunks(self) -> dict:
|
||||
assert self.backend == "zarr"
|
||||
chunks = {}
|
||||
for key, value in self.data.items():
|
||||
chunks[key] = value.chunks
|
||||
return chunks
|
||||
|
||||
def set_chunks(self, chunks: dict):
|
||||
assert self.backend == "zarr"
|
||||
for key, value in chunks.items():
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
if value != arr.chunks:
|
||||
check_chunks_compatible(chunks=value, shape=arr.shape)
|
||||
rechunk_recompress_array(self.data, key, chunks=value)
|
||||
|
||||
def get_compressors(self) -> dict:
|
||||
assert self.backend == "zarr"
|
||||
compressors = {}
|
||||
for key, value in self.data.items():
|
||||
compressors[key] = value.compressor
|
||||
return compressors
|
||||
|
||||
def set_compressors(self, compressors: dict):
|
||||
assert self.backend == "zarr"
|
||||
for key, value in compressors.items():
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
compressor = self.resolve_compressor(value)
|
||||
if compressor != arr.compressor:
|
||||
rechunk_recompress_array(self.data, key, compressor=compressor)
|
||||
@@ -1,179 +0,0 @@
|
||||
"""
|
||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||
useless dependencies when using datasets.
|
||||
"""
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
|
||||
|
||||
def download_raw(root, dataset_id) -> Path:
|
||||
if "pusht" in dataset_id:
|
||||
return download_pusht(root=root, dataset_id=dataset_id)
|
||||
elif "xarm" in dataset_id:
|
||||
return download_xarm(root=root, dataset_id=dataset_id)
|
||||
elif "aloha" in dataset_id:
|
||||
return download_aloha(root=root, dataset_id=dataset_id)
|
||||
elif "umi" in dataset_id:
|
||||
return download_umi(root=root, dataset_id=dataset_id)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path:
|
||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
root = Path(root)
|
||||
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||
zarr_path: Path = (raw_dir / pusht_zarr).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(pusht_url, raw_dir)
|
||||
return zarr_path
|
||||
|
||||
|
||||
def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
|
||||
root = Path(root)
|
||||
raw_dir: Path = root / "xarm_datasets_raw"
|
||||
if not raw_dir.exists():
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
zip_path.unlink()
|
||||
|
||||
dataset_path: Path = root / f"{dataset_id}"
|
||||
return dataset_path
|
||||
|
||||
|
||||
def download_aloha(root: str, dataset_id: str) -> Path:
|
||||
folder_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||
}
|
||||
|
||||
ep48_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||
}
|
||||
|
||||
ep49_urls = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||
}
|
||||
num_episodes = { # noqa: F841 # we keep this for reference
|
||||
"aloha_sim_insertion_human": 50,
|
||||
"aloha_sim_insertion_scripted": 50,
|
||||
"aloha_sim_transfer_cube_human": 50,
|
||||
"aloha_sim_transfer_cube_scripted": 50,
|
||||
}
|
||||
|
||||
episode_len = { # noqa: F841 # we keep this for reference
|
||||
"aloha_sim_insertion_human": 500,
|
||||
"aloha_sim_insertion_scripted": 400,
|
||||
"aloha_sim_transfer_cube_human": 400,
|
||||
"aloha_sim_transfer_cube_scripted": 400,
|
||||
}
|
||||
|
||||
cameras = { # noqa: F841 # we keep this for reference
|
||||
"aloha_sim_insertion_human": ["top"],
|
||||
"aloha_sim_insertion_scripted": ["top"],
|
||||
"aloha_sim_transfer_cube_human": ["top"],
|
||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||
}
|
||||
root = Path(root)
|
||||
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||
if not raw_dir.is_dir():
|
||||
import gdown
|
||||
|
||||
assert dataset_id in folder_urls
|
||||
assert dataset_id in ep48_urls
|
||||
assert dataset_id in ep49_urls
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
|
||||
|
||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
return raw_dir
|
||||
|
||||
|
||||
def download_umi(root: str, dataset_id: str) -> Path:
|
||||
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
|
||||
cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
|
||||
|
||||
root = Path(root)
|
||||
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||
zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
|
||||
return zarr_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
dataset_ids = [
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"umi_cup_in_the_wild",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
download_raw(root=root, dataset_id=dataset_id)
|
||||
@@ -1,311 +0,0 @@
|
||||
# imagecodecs/numcodecs.py
|
||||
|
||||
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
|
||||
"""Additional numcodecs implemented using imagecodecs."""
|
||||
|
||||
__version__ = "2022.9.26"
|
||||
|
||||
__all__ = ("register_codecs",)
|
||||
|
||||
import imagecodecs
|
||||
import numpy
|
||||
from numcodecs.abc import Codec
|
||||
from numcodecs.registry import get_codec, register_codec
|
||||
|
||||
# TODO (azouitine): Remove useless codecs
|
||||
|
||||
|
||||
def protective_squeeze(x: numpy.ndarray):
|
||||
"""
|
||||
Squeeze dim only if it's not the last dim.
|
||||
Image dim expected to be *, H, W, C
|
||||
"""
|
||||
img_shape = x.shape[-3:]
|
||||
if len(x.shape) > 3:
|
||||
n_imgs = numpy.prod(x.shape[:-3])
|
||||
if n_imgs > 1:
|
||||
img_shape = (-1,) + img_shape
|
||||
return x.reshape(img_shape)
|
||||
|
||||
|
||||
def get_default_image_compressor(**kwargs):
|
||||
if imagecodecs.JPEGXL:
|
||||
# has JPEGXL
|
||||
this_kwargs = {
|
||||
"effort": 3,
|
||||
"distance": 0.3,
|
||||
# bug in libjxl, invalid codestream for non-lossless
|
||||
# when decoding speed > 1
|
||||
"decodingspeed": 1,
|
||||
}
|
||||
this_kwargs.update(kwargs)
|
||||
return JpegXl(**this_kwargs)
|
||||
else:
|
||||
this_kwargs = {"level": 50}
|
||||
this_kwargs.update(kwargs)
|
||||
return Jpeg2k(**this_kwargs)
|
||||
|
||||
|
||||
class Jpeg2k(Codec):
|
||||
"""JPEG 2000 codec for numcodecs."""
|
||||
|
||||
codec_id = "imagecodecs_jpeg2k"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
level=None,
|
||||
codecformat=None,
|
||||
colorspace=None,
|
||||
tile=None,
|
||||
reversible=None,
|
||||
bitspersample=None,
|
||||
resolutions=None,
|
||||
numthreads=None,
|
||||
verbose=0,
|
||||
):
|
||||
self.level = level
|
||||
self.codecformat = codecformat
|
||||
self.colorspace = colorspace
|
||||
self.tile = None if tile is None else tuple(tile)
|
||||
self.reversible = reversible
|
||||
self.bitspersample = bitspersample
|
||||
self.resolutions = resolutions
|
||||
self.numthreads = numthreads
|
||||
self.verbose = verbose
|
||||
|
||||
def encode(self, buf):
|
||||
buf = protective_squeeze(numpy.asarray(buf))
|
||||
return imagecodecs.jpeg2k_encode(
|
||||
buf,
|
||||
level=self.level,
|
||||
codecformat=self.codecformat,
|
||||
colorspace=self.colorspace,
|
||||
tile=self.tile,
|
||||
reversible=self.reversible,
|
||||
bitspersample=self.bitspersample,
|
||||
resolutions=self.resolutions,
|
||||
numthreads=self.numthreads,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
def decode(self, buf, out=None):
|
||||
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||
|
||||
|
||||
class JpegXl(Codec):
|
||||
"""JPEG XL codec for numcodecs."""
|
||||
|
||||
codec_id = "imagecodecs_jpegxl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# encode
|
||||
level=None,
|
||||
effort=None,
|
||||
distance=None,
|
||||
lossless=None,
|
||||
decodingspeed=None,
|
||||
photometric=None,
|
||||
planar=None,
|
||||
usecontainer=None,
|
||||
# decode
|
||||
index=None,
|
||||
keeporientation=None,
|
||||
# both
|
||||
numthreads=None,
|
||||
):
|
||||
"""
|
||||
Return JPEG XL image from numpy array.
|
||||
Float must be in nominal range 0..1.
|
||||
|
||||
Currently L, LA, RGB, RGBA images are supported in contig mode.
|
||||
Extra channels are only supported for grayscale images in planar mode.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
level : Default to None, i.e. not overwriting lossess and decodingspeed options.
|
||||
When < 0: Use lossless compression
|
||||
When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
|
||||
Minimum is 0 (slowest to decode, best quality/density), and maximum
|
||||
is 4 (fastest to decode, at the cost of some quality/density).
|
||||
effort : Default to 3.
|
||||
Sets encoder effort/speed level without affecting decoding speed.
|
||||
Valid values are, from faster to slower speed: 1:lightning 2:thunder
|
||||
3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
|
||||
Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
|
||||
control the encoder effort in ascending order.
|
||||
This also affects memory usage: using lower effort will typically reduce memory
|
||||
consumption during encoding.
|
||||
lightning and thunder are fast modes useful for lossless mode (modular).
|
||||
falcon disables all of the following tools.
|
||||
cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
|
||||
hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
|
||||
wombat enables error diffusion quantization and full DCT size selection heuristics.
|
||||
squirrel (default) enables dots, patches, and spline detection, and full context clustering.
|
||||
kitten optimizes the adaptive quantization for a psychovisual metric.
|
||||
tortoise enables a more thorough adaptive quantization search.
|
||||
distance : Default to 1.0
|
||||
Sets the distance level for lossy compression: target max butteraugli distance,
|
||||
lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
|
||||
(however, use JxlEncoderSetFrameLossless instead to use true lossless,
|
||||
as setting distance to 0 alone is not the only requirement).
|
||||
1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
|
||||
lossess : Default to False.
|
||||
Use lossess encoding.
|
||||
decodingspeed : Default to 0.
|
||||
Duplicate to level. [0,4]
|
||||
photometric : Return JxlColorSpace value.
|
||||
Default logic is quite complicated but works most of the time.
|
||||
Accepted value:
|
||||
int: [-1,3]
|
||||
str: ['RGB',
|
||||
'WHITEISZERO', 'MINISWHITE',
|
||||
'BLACKISZERO', 'MINISBLACK', 'GRAY',
|
||||
'XYB', 'KNOWN']
|
||||
planar : Enable multi-channel mode.
|
||||
Default to false.
|
||||
usecontainer :
|
||||
Forces the encoder to use the box-based container format (BMFF)
|
||||
even when not necessary.
|
||||
When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
|
||||
JxlEncoderSetCodestreamLevel with level 10, the encoder will
|
||||
automatically also use the container format, it is not necessary
|
||||
to use JxlEncoderUseContainer for those use cases.
|
||||
By default this setting is disabled.
|
||||
index : Selectively decode frames for animation.
|
||||
Default to 0, decode all frames.
|
||||
When set to > 0, decode that frame index only.
|
||||
keeporientation :
|
||||
Enables or disables preserving of as-in-bitstream pixeldata orientation.
|
||||
Some images are encoded with an Orientation tag indicating that the
|
||||
decoder must perform a rotation and/or mirroring to the encoded image data.
|
||||
|
||||
If skip_reorientation is JXL_FALSE (the default): the decoder will apply
|
||||
the transformation from the orientation setting, hence rendering the image
|
||||
according to its specified intent. When producing a JxlBasicInfo, the decoder
|
||||
will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
|
||||
returned pixel data) and also align xsize and ysize so that they correspond
|
||||
to the width and the height of the returned pixel data.
|
||||
|
||||
If skip_reorientation is JXL_TRUE: the decoder will skip applying the
|
||||
transformation from the orientation setting, returning the image in
|
||||
the as-in-bitstream pixeldata orientation. This may be faster to decode
|
||||
since the decoder doesnt have to apply the transformation, but can
|
||||
cause wrong display of the image if the orientation tag is not correctly
|
||||
taken into account by the user.
|
||||
|
||||
By default, this option is disabled, and the returned pixel data is
|
||||
re-oriented according to the images Orientation setting.
|
||||
threads : Default to 1.
|
||||
If <= 0, use all cores.
|
||||
If > 32, clipped to 32.
|
||||
"""
|
||||
|
||||
self.level = level
|
||||
self.effort = effort
|
||||
self.distance = distance
|
||||
self.lossless = bool(lossless)
|
||||
self.decodingspeed = decodingspeed
|
||||
self.photometric = photometric
|
||||
self.planar = planar
|
||||
self.usecontainer = usecontainer
|
||||
self.index = index
|
||||
self.keeporientation = keeporientation
|
||||
self.numthreads = numthreads
|
||||
|
||||
def encode(self, buf):
|
||||
# TODO: only squeeze all but last dim
|
||||
buf = protective_squeeze(numpy.asarray(buf))
|
||||
return imagecodecs.jpegxl_encode(
|
||||
buf,
|
||||
level=self.level,
|
||||
effort=self.effort,
|
||||
distance=self.distance,
|
||||
lossless=self.lossless,
|
||||
decodingspeed=self.decodingspeed,
|
||||
photometric=self.photometric,
|
||||
planar=self.planar,
|
||||
usecontainer=self.usecontainer,
|
||||
numthreads=self.numthreads,
|
||||
)
|
||||
|
||||
def decode(self, buf, out=None):
|
||||
return imagecodecs.jpegxl_decode(
|
||||
buf,
|
||||
index=self.index,
|
||||
keeporientation=self.keeporientation,
|
||||
numthreads=self.numthreads,
|
||||
out=out,
|
||||
)
|
||||
|
||||
|
||||
def _flat(out):
|
||||
"""Return numpy array as contiguous view of bytes if possible."""
|
||||
if out is None:
|
||||
return None
|
||||
view = memoryview(out)
|
||||
if view.readonly or not view.contiguous:
|
||||
return None
|
||||
return view.cast("B")
|
||||
|
||||
|
||||
def register_codecs(codecs=None, force=False, verbose=True):
|
||||
"""Register codecs in this module with numcodecs."""
|
||||
for name, cls in globals().items():
|
||||
if not hasattr(cls, "codec_id") or name == "Codec":
|
||||
continue
|
||||
if codecs is not None and cls.codec_id not in codecs:
|
||||
continue
|
||||
try:
|
||||
try: # noqa: SIM105
|
||||
get_codec({"id": cls.codec_id})
|
||||
except TypeError:
|
||||
# registered, but failed
|
||||
pass
|
||||
except ValueError:
|
||||
# not registered yet
|
||||
pass
|
||||
else:
|
||||
if not force:
|
||||
if verbose:
|
||||
log_warning(f"numcodec {cls.codec_id!r} already registered")
|
||||
continue
|
||||
if verbose:
|
||||
log_warning(f"replacing registered numcodec {cls.codec_id!r}")
|
||||
register_codec(cls)
|
||||
|
||||
|
||||
def log_warning(msg, *args, **kwargs):
|
||||
"""Log message with level WARNING."""
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).warning(msg, *args, **kwargs)
|
||||
@@ -1,199 +0,0 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
class AlohaProcessor:
|
||||
"""
|
||||
Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act
|
||||
|
||||
Attributes:
|
||||
folder_path (Path): Path to the directory containing HDF5 files.
|
||||
cameras (list[str]): List of camera identifiers to check in the files.
|
||||
fps (int): Frames per second used in timestamp calculations.
|
||||
|
||||
Methods:
|
||||
is_valid() -> bool:
|
||||
Validates if each HDF5 file within the folder contains all required datasets.
|
||||
preprocess() -> dict:
|
||||
Processes the files and returns structured data suitable for further analysis.
|
||||
to_hf_dataset(data_dict: dict) -> Dataset:
|
||||
Converts processed data into a Hugging Face Dataset object.
|
||||
"""
|
||||
|
||||
def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None):
|
||||
"""
|
||||
Initializes the AlohaProcessor with a specified directory path containing HDF5 files,
|
||||
an optional list of cameras, and a frame rate.
|
||||
|
||||
Args:
|
||||
folder_path (Path): The directory path where HDF5 files are stored.
|
||||
cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None.
|
||||
fps (int): Frame rate for the datasets, used in time calculations. Default is 50.
|
||||
|
||||
Examples:
|
||||
>>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"])
|
||||
>>> processor.is_valid()
|
||||
True
|
||||
"""
|
||||
self.folder_path = folder_path
|
||||
if cameras is None:
|
||||
cameras = ["top"]
|
||||
self.cameras = cameras
|
||||
if fps is None:
|
||||
fps = 50
|
||||
self._fps = fps
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self._fps
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
Validates the HDF5 files in the specified folder to ensure they contain the required datasets
|
||||
for actions, positions, and images for each specified camera.
|
||||
|
||||
Returns:
|
||||
bool: True if all files are valid HDF5 files with all required datasets, False otherwise.
|
||||
"""
|
||||
hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5"))
|
||||
if len(hdf5_files) == 0:
|
||||
return False
|
||||
try:
|
||||
hdf5_files = sorted(
|
||||
hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1))
|
||||
)
|
||||
except AttributeError:
|
||||
# All file names must contain a numerical identifier matching 'episode_(\\d+).hdf5
|
||||
return False
|
||||
|
||||
# Check if the sequence is consecutive eg episode_0, episode_1, episode_2, etc.
|
||||
# If not, return False
|
||||
previous_number = None
|
||||
for file in hdf5_files:
|
||||
current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
|
||||
if previous_number is not None and current_number - previous_number != 1:
|
||||
return False
|
||||
previous_number = current_number
|
||||
|
||||
for file in hdf5_files:
|
||||
try:
|
||||
with h5py.File(file, "r") as file:
|
||||
# Check for the expected datasets within the HDF5 file
|
||||
required_datasets = ["/action", "/observations/qpos"]
|
||||
# Add camera-specific image datasets to the required datasets
|
||||
camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras]
|
||||
required_datasets.extend(camera_datasets)
|
||||
|
||||
if not all(dataset in file for dataset in required_datasets):
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def preprocess(self):
|
||||
"""
|
||||
Collects episode data from the HDF5 file and returns it as an AlohaStep named tuple.
|
||||
|
||||
Returns:
|
||||
AlohaStep: Named tuple containing episode data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the file is not valid.
|
||||
"""
|
||||
if not self.is_valid():
|
||||
raise ValueError("The HDF5 file is invalid or does not contain the required datasets.")
|
||||
|
||||
hdf5_files = list(self.folder_path.glob("*.hdf5"))
|
||||
hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
|
||||
for ep_path in tqdm.tqdm(hdf5_files):
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
for cam in self.cameras:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||
|
||||
ep_dict.update(
|
||||
{
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
def to_hf_dataset(self, data_dict) -> Dataset:
|
||||
"""
|
||||
Converts a dictionary of data into a Hugging Face Dataset object.
|
||||
|
||||
Args:
|
||||
data_dict (dict): A dictionary containing the data to be converted.
|
||||
|
||||
Returns:
|
||||
Dataset: The converted Hugging Face Dataset object.
|
||||
"""
|
||||
image_features = {f"observation.images.{cam}": Image() for cam in self.cameras}
|
||||
features = {
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
# "next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
# "next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
update_features = {**image_features, **features}
|
||||
features = Features(update_features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
@@ -1,180 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
class PushTProcessor:
|
||||
""" Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy
|
||||
"""
|
||||
def __init__(self, folder_path: Path, fps: int | None = None):
|
||||
self.zarr_path = folder_path
|
||||
if fps is None:
|
||||
fps = 10
|
||||
self._fps = fps
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self._fps
|
||||
|
||||
def is_valid(self):
|
||||
try:
|
||||
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||
except Exception:
|
||||
# TODO (azouitine): Handle the exception properly
|
||||
return False
|
||||
required_datasets = {
|
||||
"data/action",
|
||||
"data/img",
|
||||
"data/keypoint",
|
||||
"data/n_contacts",
|
||||
"data/state",
|
||||
"meta/episode_ends",
|
||||
}
|
||||
for dataset in required_datasets:
|
||||
if dataset not in zarr_data:
|
||||
return False
|
||||
nb_frames = zarr_data["data/img"].shape[0]
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
|
||||
return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
def preprocess(self):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
# as define in env
|
||||
success_threshold = 0.95 # 95% coverage,
|
||||
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
self.zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = id_to - id_from
|
||||
|
||||
assert (episode_ids[id_from:id_to] == episode_id).all()
|
||||
|
||||
image = imgs[id_from:id_to]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[id_from:id_to]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[id_from:id_to],
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
def to_hf_dataset(self, data_dict):
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
"next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
@@ -1,280 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
class UmiProcessor:
|
||||
"""
|
||||
Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface
|
||||
|
||||
Attributes:
|
||||
folder_path (str): The path to the folder containing Zarr datasets.
|
||||
fps (int): Frames per second, used to calculate timestamps for frames.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, folder_path: str, fps: int | None = None):
|
||||
self.zarr_path = folder_path
|
||||
if fps is None:
|
||||
# TODO (azouitine): Add reference to the paper
|
||||
fps = 15
|
||||
self._fps = fps
|
||||
register_codecs()
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self._fps
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
Validates the Zarr folder to ensure it contains all required datasets with consistent frame counts.
|
||||
|
||||
Returns:
|
||||
bool: True if all required datasets are present and have consistent frame counts, False otherwise.
|
||||
"""
|
||||
# Check if the Zarr folder is valid
|
||||
try:
|
||||
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||
except Exception:
|
||||
# TODO (azouitine): Handle the exception properly
|
||||
return False
|
||||
required_datasets = {
|
||||
"data/robot0_demo_end_pose",
|
||||
"data/robot0_demo_start_pose",
|
||||
"data/robot0_eef_pos",
|
||||
"data/robot0_eef_rot_axis_angle",
|
||||
"data/robot0_gripper_width",
|
||||
"meta/episode_ends",
|
||||
"data/camera0_rgb",
|
||||
}
|
||||
for dataset in required_datasets:
|
||||
if dataset not in zarr_data:
|
||||
return False
|
||||
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
|
||||
|
||||
required_datasets.remove("meta/episode_ends")
|
||||
|
||||
return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
def preprocess(self):
|
||||
"""
|
||||
Collects and processes all episodes from the Zarr dataset into structured data dictionaries.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: A tuple containing the structured episode data and episode index mappings.
|
||||
"""
|
||||
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||
|
||||
# We process the image data separately because it is too large to fit in memory
|
||||
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||
|
||||
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||
states = torch.cat([states_pos, gripper_width], dim=1)
|
||||
|
||||
episode_ends = zarr_data["meta/episode_ends"][:]
|
||||
num_episodes: int = episode_ends.shape[0]
|
||||
|
||||
episode_ids = torch.from_numpy(self.get_episode_idxs(episode_ends))
|
||||
|
||||
# We convert it in torch tensor later because the jit function does not support torch tensors
|
||||
episode_ends = torch.from_numpy(episode_ends)
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
id_from = 0
|
||||
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = episode_ends[episode_id]
|
||||
|
||||
num_frames = id_to - id_from
|
||||
|
||||
assert (
|
||||
episode_ids[id_from:id_to] == episode_id
|
||||
).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}"
|
||||
|
||||
state = states[id_from:id_to]
|
||||
ep_dict = {
|
||||
# observation.image will be filled later
|
||||
"observation.state": state,
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||
"end_pose": end_pose[id_from:id_to],
|
||||
"start_pos": start_pos[id_from:id_to],
|
||||
"gripper_width": gripper_width[id_from:id_to],
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
print("Saving images to disk in temporary folder...")
|
||||
# datasets.Image() can take a list of paths to images, so we save the images to a temporary folder
|
||||
# to avoid loading them all in memory
|
||||
_save_images_concurrently(
|
||||
data=zarr_data, image_key="data/camera0_rgb", folder_path="tmp_umi_images", max_workers=12
|
||||
)
|
||||
print("Saving images to disk in temporary folder... Done")
|
||||
|
||||
# Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png
|
||||
# to correctly match the images with the data
|
||||
images_path = sorted(
|
||||
glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1))
|
||||
)
|
||||
data_dict["observation.image"] = images_path
|
||||
print("Images saved to disk, do not forget to delete the folder tmp_umi_images/")
|
||||
|
||||
# Cleanup
|
||||
return data_dict, episode_data_index
|
||||
|
||||
def to_hf_dataset(self, data_dict):
|
||||
"""
|
||||
Converts the processed data dictionary into a Hugging Face dataset with defined features.
|
||||
|
||||
Args:
|
||||
data_dict (Dict): The data dictionary containing tensors and episode information.
|
||||
|
||||
Returns:
|
||||
Dataset: A Hugging Face dataset constructed from the provided data dictionary.
|
||||
"""
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||
# `start_pos` and `end_pos` respectively represent the positions of the end-effector
|
||||
# at the beginning and the end of the episode.
|
||||
# `gripper_width` indicates the distance between the grippers, and this value is included
|
||||
# in the state vector, which comprises the concatenation of the end-effector position
|
||||
# and gripper width.
|
||||
"end_pose": Sequence(
|
||||
length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"start_pos": Sequence(
|
||||
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"gripper_width": Sequence(
|
||||
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
def cleanup(self):
|
||||
# Cleanup
|
||||
if os.path.exists("tmp_umi_images"):
|
||||
print("Removing temporary images folder")
|
||||
shutil.rmtree("tmp_umi_images")
|
||||
print("Cleanup done")
|
||||
|
||||
@classmethod
|
||||
def get_episode_idxs(cls, episode_ends: np.ndarray) -> np.ndarray:
|
||||
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
|
||||
from numba import jit
|
||||
|
||||
@jit(nopython=True)
|
||||
def _get_episode_idxs(episode_ends):
|
||||
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||
start_idx = 0
|
||||
for episode_number, end_idx in enumerate(episode_ends):
|
||||
result[start_idx:end_idx] = episode_number
|
||||
start_idx = end_idx
|
||||
return result
|
||||
|
||||
return _get_episode_idxs(episode_ends)
|
||||
|
||||
|
||||
def _clear_folder(folder_path: str):
|
||||
"""
|
||||
Clears all the content of the specified folder. Creates the folder if it does not exist.
|
||||
|
||||
Args:
|
||||
folder_path (str): Path to the folder to clear.
|
||||
|
||||
Examples:
|
||||
>>> import os
|
||||
>>> os.makedirs('example_folder', exist_ok=True)
|
||||
>>> with open('example_folder/temp_file.txt', 'w') as f:
|
||||
... f.write('example')
|
||||
>>> clear_folder('example_folder')
|
||||
>>> os.listdir('example_folder')
|
||||
[]
|
||||
"""
|
||||
if os.path.exists(folder_path):
|
||||
for filename in os.listdir(folder_path):
|
||||
file_path = os.path.join(folder_path, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path)
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
os.makedirs(folder_path)
|
||||
|
||||
|
||||
def _save_image(img_array: np.array, i: int, folder_path: str):
|
||||
"""
|
||||
Saves a single image to the specified folder.
|
||||
|
||||
Args:
|
||||
img_array (ndarray): The numpy array of the image.
|
||||
i (int): Index of the image, used for naming.
|
||||
folder_path (str): Path to the folder where the image will be saved.
|
||||
"""
|
||||
img = PILImage.fromarray(img_array)
|
||||
img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG"
|
||||
img.save(os.path.join(folder_path, f"{i}.{img_format.lower()}"), quality=100)
|
||||
|
||||
|
||||
def _save_images_concurrently(data: dict, image_key: str, folder_path: str, max_workers: int = 4):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
"""
|
||||
Saves images from the zarr_data to the specified folder using multithreading.
|
||||
|
||||
Args:
|
||||
zarr_data (dict): A dictionary containing image data in an array format.
|
||||
folder_path (str): Path to the folder where images will be saved.
|
||||
max_workers (int): The maximum number of threads to use for saving images.
|
||||
"""
|
||||
num_images = len(data["data/camera0_rgb"])
|
||||
_clear_folder(folder_path) # Clear or create folder first
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(_save_image, data[image_key][i], i, folder_path) for i in range(num_images)]
|
||||
@@ -1,20 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
def concatenate_episodes(ep_dicts):
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
@@ -1,145 +0,0 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
class XarmProcessor:
|
||||
"""Process pickle files formatted like in: https://github.com/fyhMer/fowm"""
|
||||
|
||||
def __init__(self, folder_path: str, fps: int | None = None):
|
||||
self.folder_path = Path(folder_path)
|
||||
self.keys = {"actions", "rewards", "dones", "masks"}
|
||||
self.nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
|
||||
if fps is None:
|
||||
fps = 15
|
||||
self._fps = fps
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self._fps
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
# get all .pkl files
|
||||
xarm_files = list(self.folder_path.glob("*.pkl"))
|
||||
if len(xarm_files) != 1:
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(xarm_files[0], "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if not isinstance(dataset_dict, dict):
|
||||
return False
|
||||
|
||||
if not all(k in dataset_dict for k in self.keys):
|
||||
return False
|
||||
|
||||
# Check for consistent lengths in nested keys
|
||||
try:
|
||||
expected_len = len(dataset_dict["actions"])
|
||||
if any(len(dataset_dict[key]) != expected_len for key in self.keys if key in dataset_dict):
|
||||
return False
|
||||
|
||||
for key, subkeys in self.nested_keys.items():
|
||||
nested_dict = dataset_dict.get(key, {})
|
||||
if any(
|
||||
len(nested_dict[subkey]) != expected_len for subkey in subkeys if subkey in nested_dict
|
||||
):
|
||||
return False
|
||||
except KeyError: # If any expected key or subkey is missing
|
||||
return False
|
||||
|
||||
return True # All checks passed
|
||||
|
||||
def preprocess(self):
|
||||
if not self.is_valid():
|
||||
raise ValueError("The Xarm file is invalid or does not contain the required datasets.")
|
||||
|
||||
xarm_files = list(self.folder_path.glob("*.pkl"))
|
||||
|
||||
with open(xarm_files[0], "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
episode_id = 0
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
id_to += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = id_to - id_from
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
|
||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
|
||||
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
|
||||
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from = id_to
|
||||
episode_id += 1
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
def to_hf_dataset(self, data_dict):
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
212
lerobot/common/datasets/pusht.py
Normal file
212
lerobot/common/datasets/pusht.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from diffusion_policy.common.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
from diffusion_policy.env.pusht.pusht_env import pymunk_to_shapely
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
|
||||
# as define in env
|
||||
SUCCESS_THRESHOLD = 0.95 # 95% coverage,
|
||||
|
||||
DEFAULT_TEE_MASK = pymunk.ShapeFilter.ALL_MASKS()
|
||||
PUSHT_URL = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
PUSHT_ZARR = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||
|
||||
|
||||
def get_goal_pose_body(pose):
|
||||
mass = 1
|
||||
inertia = pymunk.moment_for_box(mass, (50, 100))
|
||||
body = pymunk.Body(mass, inertia)
|
||||
# preserving the legacy assignment order for compatibility
|
||||
# the order here doesn't matter somehow, maybe because CoM is aligned with body origin
|
||||
body.position = pose[:2].tolist()
|
||||
body.angle = pose[2]
|
||||
return body
|
||||
|
||||
|
||||
def add_segment(space, a, b, radius):
|
||||
shape = pymunk.Segment(space.static_body, a, b, radius)
|
||||
shape.color = pygame.Color("LightGray") # https://htmlcolorcodes.com/color-names
|
||||
return shape
|
||||
|
||||
|
||||
def add_tee(
|
||||
space,
|
||||
position,
|
||||
angle,
|
||||
scale=30,
|
||||
color="LightSlateGray",
|
||||
mask=DEFAULT_TEE_MASK,
|
||||
):
|
||||
mass = 1
|
||||
length = 4
|
||||
vertices1 = [
|
||||
(-length * scale / 2, scale),
|
||||
(length * scale / 2, scale),
|
||||
(length * scale / 2, 0),
|
||||
(-length * scale / 2, 0),
|
||||
]
|
||||
inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
vertices2 = [
|
||||
(-scale / 2, scale),
|
||||
(-scale / 2, length * scale),
|
||||
(scale / 2, length * scale),
|
||||
(scale / 2, scale),
|
||||
]
|
||||
inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1)
|
||||
body = pymunk.Body(mass, inertia1 + inertia2)
|
||||
shape1 = pymunk.Poly(body, vertices1)
|
||||
shape2 = pymunk.Poly(body, vertices2)
|
||||
shape1.color = pygame.Color(color)
|
||||
shape2.color = pygame.Color(color)
|
||||
shape1.filter = pymunk.ShapeFilter(mask=mask)
|
||||
shape2.filter = pymunk.ShapeFilter(mask=mask)
|
||||
body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2
|
||||
body.position = position
|
||||
body.angle = angle
|
||||
body.friction = 1
|
||||
space.add(body, shape1, shape2)
|
||||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
|
||||
zarr_path = (raw_dir / PUSHT_ZARR).resolve()
|
||||
if not zarr_path.is_dir():
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(PUSHT_URL, raw_dir)
|
||||
|
||||
# load
|
||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||
zarr_path
|
||||
) # , keys=['img', 'state', 'action'])
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
|
||||
# TODO: verify that goal pose is expected to be fixed
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
idx0 = 0
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
|
||||
state = states[idx0:idx1]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
reward = torch.zeros(num_frames, 1)
|
||||
success = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
add_segment(space, (5, 506), (5, 5), 2),
|
||||
add_segment(space, (5, 5), (506, 5), 2),
|
||||
add_segment(space, (506, 5), (506, 506), 2),
|
||||
add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body = add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / SUCCESS_THRESHOLD, 0, 1)
|
||||
success[i] = coverage > SUCCESS_THRESHOLD
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
120
lerobot/common/datasets/simxarm.py
Normal file
120
lerobot/common/datasets/simxarm.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
SliceSampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
|
||||
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
import gdown
|
||||
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
def _download_and_preproc(self):
|
||||
# TODO(rcadene): finish download
|
||||
download()
|
||||
|
||||
dataset_path = self.data_dir / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset_dict = pickle.load(f)
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): state,
|
||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "observation", "reward"): next_reward,
|
||||
("next", "observation", "done"): next_done,
|
||||
},
|
||||
batch_size=num_frames,
|
||||
)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir)
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
episode_id += 1
|
||||
idx0 = idx1
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
@@ -1,358 +1,30 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
import io
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import einops
|
||||
import torch
|
||||
import requests
|
||||
import tqdm
|
||||
from datasets import Image, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
d[parts[-1]] = value
|
||||
return outdict
|
||||
zip_file.seek(0)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
with channel first (c h w) of float32 type in range [0,1].
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
return True
|
||||
else:
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_info(repo_id, version, root) -> dict:
|
||||
"""info contains useful information regarding the dataset that are not stored elsewhere
|
||||
|
||||
Example:
|
||||
```python
|
||||
print("frame per second used to collect the video", info["fps"])
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
||||
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float,
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
||||
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
||||
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
||||
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
||||
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
|
||||
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
|
||||
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
|
||||
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
|
||||
of the episode are the same as the first frame of the episode.
|
||||
|
||||
Parameters:
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
|
||||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||
modality (e.g., "timestamp", "observation.image", "action").
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
|
||||
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
||||
smallest expected inter-frame period, but large enough to account for jitter.
|
||||
|
||||
Returns:
|
||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
|
||||
each modality (e.g. "observation.image_is_pad").
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
|
||||
issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_id = item["episode_index"].item()
|
||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = torch.stack(ep_timestamps)
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
current_ts = item["timestamp"].item()
|
||||
|
||||
for key in delta_timestamps:
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
is_pad = min_ > tol
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
# get dataset indices corresponding to frames to be loaded
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# load frames modality
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
item[key] = torch.stack(item[key])
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def get_stats_einops_patterns(hf_dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
|
||||
Note: We assume the images of `hf_dataset` are in channel first format
|
||||
"""
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=False,
|
||||
)
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in hf_dataset.features.items():
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
if isinstance(feats_type, Image):
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
stats_patterns[key] = "b c -> c "
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
|
||||
return stats_patterns
|
||||
|
||||
|
||||
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(hf_dataset)
|
||||
|
||||
stats_patterns = get_stats_einops_patterns(hf_dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
for key in stats_patterns:
|
||||
mean[key] = torch.tensor(0.0).float()
|
||||
std[key] = torch.tensor(0.0).float()
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation.
|
||||
batch_mean = einops.reduce(batch[key], pattern, "mean")
|
||||
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
|
||||
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
first_batch_ = deepcopy(batch)
|
||||
for key in stats_patterns:
|
||||
assert torch.equal(first_batch_[key], first_batch[key])
|
||||
for key, pattern in stats_patterns.items():
|
||||
batch[key] = batch[key].float()
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
|
||||
for key in stats_patterns:
|
||||
std[key] = torch.sqrt(std[key])
|
||||
|
||||
stats = {}
|
||||
for key in stats_patterns:
|
||||
stats[key] = {
|
||||
"mean": mean[key],
|
||||
"std": std[key],
|
||||
"max": max[key],
|
||||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
return False
|
||||
|
||||
0
lerobot/common/envs/__init__.py
Normal file
0
lerobot/common/envs/__init__.py
Normal file
75
lerobot/common/envs/abstract.py
Normal file
75
lerobot/common/envs/abstract.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import abc
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
self._rendering_hooks = []
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def register_rendering_hook(self, func):
|
||||
self._rendering_hooks.append(func)
|
||||
|
||||
def call_rendering_hooks(self):
|
||||
for func in self._rendering_hooks:
|
||||
func(self)
|
||||
|
||||
def reset_rendering_hooks(self):
|
||||
self._rendering_hooks = []
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _step(self, tensordict: TensorDict):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _make_spec(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,59 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,48 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
|
||||
<equality>
|
||||
<weld body1="mocap_left" body2="vx300s_left/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
<weld body1="mocap_right" body2="vx300s_right/gripper_link" solref="0.01 1" solimp=".25 .25 0.001" />
|
||||
</equality>
|
||||
|
||||
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body mocap="true" name="mocap_left" pos="0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_left_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_left_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_left_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
<body mocap="true" name="mocap_right" pos="-0.095 0.50 0.425">
|
||||
<site pos="0 0 0" size="0.003 0.003 0.03" type="box" name="mocap_right_site1" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.003 0.03 0.003" type="box" name="mocap_right_site2" rgba="1 0 0 1"/>
|
||||
<site pos="0 0 0" size="0.03 0.003 0.003" type="box" name="mocap_right_site3" rgba="1 0 0 1"/>
|
||||
</body>
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,53 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="peg" pos="0.2 0.5 0.05">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.2 0.5 0.05">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<!-- <geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg_ref" rgba="1 0 0 1" />-->
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0 -0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
@@ -0,0 +1,42 @@
|
||||
<mujoco>
|
||||
<include file="scene.xml"/>
|
||||
<include file="vx300s_dependencies.xml"/>
|
||||
<worldbody>
|
||||
<include file="vx300s_left.xml" />
|
||||
<include file="vx300s_right.xml" />
|
||||
|
||||
<body name="box" pos="0.2 0.5 0.05">
|
||||
<joint name="red_box_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.02 0.02 0.02" type="box" name="red_box" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_left/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_left/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_left/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_left/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_left/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_left/right_finger" kp="200" user="1"/>
|
||||
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/waist" kp="800" user="1" forcelimited="true" forcerange="-150 150"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.85005 1.25664" joint="vx300s_right/shoulder" kp="1600" user="1" forcelimited="true" forcerange="-300 300"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.76278 1.6057" joint="vx300s_right/elbow" kp="800" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/forearm_roll" kp="10" user="1" forcelimited="true" forcerange="-100 100"/>
|
||||
<position ctrllimited="true" ctrlrange="-1.8675 2.23402" joint="vx300s_right/wrist_angle" kp="50" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-3.14158 3.14158" joint="vx300s_right/wrist_rotate" kp="20" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="0.021 0.057" joint="vx300s_right/left_finger" kp="200" user="1"/>
|
||||
<position ctrllimited="true" ctrlrange="-0.057 -0.021" joint="vx300s_right/right_finger" kp="200" user="1"/>
|
||||
|
||||
</actuator>
|
||||
|
||||
<keyframe>
|
||||
<key qpos="0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0 -0.96 1.16 0 -0.3 0 0.024 -0.024 0.2 0.5 0.05 1 0 0 0"/>
|
||||
</keyframe>
|
||||
|
||||
|
||||
</mujoco>
|
||||
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
38
lerobot/common/envs/aloha/assets/scene.xml
Normal file
@@ -0,0 +1,38 @@
|
||||
<mujocoinclude>
|
||||
<!-- <option timestep='0.0025' iterations="50" tolerance="1e-10" solver="Newton" jacobian="dense" cone="elliptic"/>-->
|
||||
|
||||
<asset>
|
||||
<mesh file="tabletop.stl" name="tabletop" scale="0.001 0.001 0.001"/>
|
||||
</asset>
|
||||
|
||||
<visual>
|
||||
<map fogstart="1.5" fogend="5" force="0.1" znear="0.1"/>
|
||||
<quality shadowsize="4096" offsamples="4"/>
|
||||
<headlight ambient="0.4 0.4 0.4"/>
|
||||
</visual>
|
||||
|
||||
<worldbody>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='-1 -1 1'
|
||||
dir='1 1 -1'/>
|
||||
<light directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='1 -1 1' dir='-1 1 -1'/>
|
||||
<light castshadow="false" directional='true' diffuse='.3 .3 .3' specular='0.3 0.3 0.3' pos='0 1 1'
|
||||
dir='0 -1 -1'/>
|
||||
|
||||
<body name="table" pos="0 .6 0">
|
||||
<geom group="1" mesh="tabletop" pos="0 0 0" type="mesh" conaffinity="1" contype="1" name="table" rgba="0.2 0.2 0.2 1" />
|
||||
</body>
|
||||
<body name="midair" pos="0 .6 0.2">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="midair" rgba="1 0 0 0"/>
|
||||
</body>
|
||||
|
||||
<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
|
||||
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
|
||||
</worldbody>
|
||||
|
||||
|
||||
|
||||
</mujocoinclude>
|
||||
BIN
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/tabletop.stl
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_1_base.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
Normal file
Binary file not shown.
BIN
lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
Normal file
BIN
lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
Normal file
Binary file not shown.
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
17
lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
Normal file
@@ -0,0 +1,17 @@
|
||||
<mujocoinclude>
|
||||
<compiler angle="radian" inertiafromgeom="auto" inertiagrouprange="4 5"/>
|
||||
<asset>
|
||||
<mesh name="vx300s_1_base" file="vx300s_1_base.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_2_shoulder" file="vx300s_2_shoulder.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_3_upper_arm" file="vx300s_3_upper_arm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_4_upper_forearm" file="vx300s_4_upper_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_5_lower_forearm" file="vx300s_5_lower_forearm.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_6_wrist" file="vx300s_6_wrist.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_7_gripper" file="vx300s_7_gripper.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_8_gripper_prop" file="vx300s_8_gripper_prop.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_9_gripper_bar" file="vx300s_9_gripper_bar.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_left" file="vx300s_10_custom_finger_left.stl" scale="0.001 0.001 0.001" />
|
||||
<mesh name="vx300s_10_gripper_finger_right" file="vx300s_10_custom_finger_right.stl" scale="0.001 0.001 0.001" />
|
||||
</asset>
|
||||
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_left.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_left" pos="-0.469 0.5 0">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_left/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_left/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_left/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_left/2_shoulder" />
|
||||
<body name="vx300s_left/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_left/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_left/3_upper_arm"/>
|
||||
<body name="vx300s_left/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_left/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_left/4_upper_forearm" />
|
||||
<body name="vx300s_left/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_left/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_left/5_lower_forearm"/>
|
||||
<body name="vx300s_left/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_left/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_left/6_wrist" />
|
||||
<body name="vx300s_left/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_left/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="left_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_left_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_left_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_left_site3" rgba="0 0 1 0"/>
|
||||
<camera name="left_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_left/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_left/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_left/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_left/9_gripper_bar" />
|
||||
<body name="vx300s_left/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_left/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_left/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_left/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_left/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_left/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_left/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_left/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_left/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
59
lerobot/common/envs/aloha/assets/vx300s_right.xml
Normal file
@@ -0,0 +1,59 @@
|
||||
|
||||
<mujocoinclude>
|
||||
<body name="vx300s_right" pos="0.469 0.5 0" euler="0 0 3.1416">
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_1_base" name="vx300s_right/1_base" contype="0" conaffinity="0"/>
|
||||
<body name="vx300s_right/shoulder_link" pos="0 0 0.079">
|
||||
<inertial pos="0.000259233 -3.3552e-06 0.0116129" quat="-0.476119 0.476083 0.52279 0.522826" mass="0.798614" diaginertia="0.00120156 0.00113744 0.0009388" />
|
||||
<joint name="vx300s_right/waist" pos="0 0 0" axis="0 0 1" limited="true" range="-3.14158 3.14158" frictionloss="50" />
|
||||
<geom pos="0 0 -0.003" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_2_shoulder" name="vx300s_right/2_shoulder" />
|
||||
<body name="vx300s_right/upper_arm_link" pos="0 0 0.04805">
|
||||
<inertial pos="0.0206949 4e-10 0.226459" quat="0 0.0728458 0 0.997343" mass="0.792592" diaginertia="0.00911338 0.008925 0.000759317" />
|
||||
<joint name="vx300s_right/shoulder" pos="0 0 0" axis="0 1 0" limited="true" range="-1.85005 1.25664" frictionloss="60" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_3_upper_arm" name="vx300s_right/3_upper_arm"/>
|
||||
<body name="vx300s_right/upper_forearm_link" pos="0.05955 0 0.3">
|
||||
<inertial pos="0.105723 0 0" quat="-0.000621631 0.704724 0.0105292 0.709403" mass="0.322228" diaginertia="0.00144107 0.00134228 0.000152047" />
|
||||
<joint name="vx300s_right/elbow" pos="0 0 0" axis="0 1 0" limited="true" range="-1.76278 1.6057" frictionloss="60" />
|
||||
<geom type="mesh" mesh="vx300s_4_upper_forearm" name="vx300s_right/4_upper_forearm" />
|
||||
<body name="vx300s_right/lower_forearm_link" pos="0.2 0 0">
|
||||
<inertial pos="0.0513477 0.00680462 0" quat="-0.702604 -0.0796724 -0.702604 0.0796724" mass="0.414823" diaginertia="0.0005911 0.000546493 0.000155707" />
|
||||
<joint name="vx300s_right/forearm_roll" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom quat="0 1 0 0" type="mesh" mesh="vx300s_5_lower_forearm" name="vx300s_right/5_lower_forearm"/>
|
||||
<body name="vx300s_right/wrist_link" pos="0.1 0 0">
|
||||
<inertial pos="0.046743 -7.6652e-06 0.010565" quat="-0.00100191 0.544586 0.0026583 0.8387" mass="0.115395" diaginertia="5.45707e-05 4.63101e-05 4.32692e-05" />
|
||||
<joint name="vx300s_right/wrist_angle" pos="0 0 0" axis="0 1 0" limited="true" range="-1.8675 2.23402" frictionloss="30" />
|
||||
<geom quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_6_wrist" name="vx300s_right/6_wrist" />
|
||||
<body name="vx300s_right/gripper_link" pos="0.069744 0 0">
|
||||
<body name="vx300s_right/camera_focus" pos="0.15 0 0.01">
|
||||
<site pos="0 0 0" size="0.01" type="sphere" name="right_cam_focus" rgba="0 0 1 0"/>
|
||||
</body>
|
||||
<site pos="0.15 0 0" size="0.003 0.003 0.03" type="box" name="cali_right_site1" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.003 0.03 0.003" type="box" name="cali_right_site2" rgba="0 0 1 0"/>
|
||||
<site pos="0.15 0 0" size="0.03 0.003 0.003" type="box" name="cali_right_site3" rgba="0 0 1 0"/>
|
||||
<camera name="right_wrist" pos="-0.1 0 0.16" fovy="20" mode="targetbody" target="vx300s_right/camera_focus"/>
|
||||
<inertial pos="0.0395662 -2.56311e-07 0.00400649" quat="0.62033 0.619916 -0.339682 0.339869" mass="0.251652" diaginertia="0.000689546 0.000650316 0.000468142" />
|
||||
<joint name="vx300s_right/wrist_rotate" pos="0 0 0" axis="1 0 0" limited="true" range="-3.14158 3.14158" frictionloss="30" />
|
||||
<geom pos="-0.02 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_7_gripper" name="vx300s_right/7_gripper" />
|
||||
<geom pos="-0.020175 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_9_gripper_bar" name="vx300s_right/9_gripper_bar" />
|
||||
<body name="vx300s_right/gripper_prop_link" pos="0.0485 0 0">
|
||||
<inertial pos="0.002378 2.85e-08 0" quat="0 0 0.897698 0.440611" mass="0.008009" diaginertia="4.2979e-06 2.8868e-06 1.5314e-06" />
|
||||
<!-- <joint name="vx300s_right/gripper" pos="0 0 0" axis="1 0 0" frictionloss="30" />-->
|
||||
<geom pos="-0.0685 0 0" quat="0.707107 0 0 0.707107" type="mesh" mesh="vx300s_8_gripper_prop" name="vx300s_right/8_gripper_prop" />
|
||||
</body>
|
||||
<body name="vx300s_right/left_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 -0.0060692 0" quat="0.449364 0.449364 -0.54596 -0.54596" mass="0.034796" diaginertia="2.48003e-05 1.417e-05 1.20797e-05" />
|
||||
<joint name="vx300s_right/left_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="0.021 0.057" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 -0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_left" name="vx300s_right/10_left_gripper_finger"/>
|
||||
</body>
|
||||
<body name="vx300s_right/right_finger_link" pos="0.0687 0 0">
|
||||
<inertial pos="0.017344 0.0060692 0" quat="0.44937 -0.44937 0.545955 -0.545955" mass="0.034796" diaginertia="2.48002e-05 1.417e-05 1.20798e-05" />
|
||||
<joint name="vx300s_right/right_finger" pos="0 0 0" axis="0 1 0" type="slide" limited="true" range="-0.057 -0.021" frictionloss="30" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0.005 0.052 0" euler="3.14 1.57 0" type="mesh" mesh="vx300s_10_gripper_finger_right" name="vx300s_right/10_right_gripper_finger"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</mujocoinclude>
|
||||
163
lerobot/common/envs/aloha/constants.py
Normal file
163
lerobot/common/envs/aloha/constants.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from pathlib import Path
|
||||
|
||||
### Simulation envs fixed constants
|
||||
DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
|
||||
FPS = 50
|
||||
|
||||
|
||||
JOINTS = [
|
||||
# absolute joint position
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"left_arm_gripper",
|
||||
# absolute joint position
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position 0: close, 1: open
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
ACTIONS = [
|
||||
# position and quaternion for end effector
|
||||
"left_arm_waist",
|
||||
"left_arm_shoulder",
|
||||
"left_arm_elbow",
|
||||
"left_arm_forearm_roll",
|
||||
"left_arm_wrist_angle",
|
||||
"left_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"left_arm_gripper",
|
||||
"right_arm_waist",
|
||||
"right_arm_shoulder",
|
||||
"right_arm_elbow",
|
||||
"right_arm_forearm_roll",
|
||||
"right_arm_wrist_angle",
|
||||
"right_arm_wrist_rotate",
|
||||
# normalized gripper position (0: close, 1: open)
|
||||
"right_arm_gripper",
|
||||
]
|
||||
|
||||
|
||||
START_ARM_POSE = [
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
0,
|
||||
-0.96,
|
||||
1.16,
|
||||
0,
|
||||
-0.3,
|
||||
0,
|
||||
0.02239,
|
||||
-0.02239,
|
||||
]
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
|
||||
def normalize_master_gripper_position(x):
|
||||
return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_position(x):
|
||||
return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_position(x):
|
||||
return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_position(x):
|
||||
return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
|
||||
|
||||
def convert_position_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
|
||||
|
||||
|
||||
def normalizer_master_gripper_joint(x):
|
||||
return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_joint(x):
|
||||
return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
|
||||
|
||||
def unnormalize_master_gripper_joint(x):
|
||||
return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def unnormalize_puppet_gripper_joint(x):
|
||||
return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
|
||||
|
||||
def convert_join_from_master_to_puppet(x):
|
||||
return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
|
||||
|
||||
|
||||
def normalize_master_gripper_velocity(x):
|
||||
return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def normalize_puppet_gripper_velocity(x):
|
||||
return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
|
||||
def convert_master_from_position_to_joint(x):
|
||||
return (
|
||||
normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_master_from_joint_to_position(x):
|
||||
return unnormalize_master_gripper_position(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_position_to_join(x):
|
||||
return (
|
||||
normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
|
||||
|
||||
def convert_puppet_from_joint_to_position(x):
|
||||
return unnormalize_puppet_gripper_position(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
306
lerobot/common/envs/aloha/env.py
Normal file
306
lerobot/common/envs/aloha/env.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import importlib
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
|
||||
from lerobot.common.envs.abstract import AbstractEnv
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
ACTIONS,
|
||||
ASSETS_DIR,
|
||||
DT,
|
||||
JOINTS,
|
||||
)
|
||||
from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
|
||||
from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||
InsertionEndEffectorTask,
|
||||
TransferCubeEndEffectorTask,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=1,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(
|
||||
task=task,
|
||||
frame_skip=frame_skip,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=image_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError()
|
||||
|
||||
self._env = self._make_env_task(task)
|
||||
|
||||
def render(self, mode="rgb_array", width=640, height=480):
|
||||
# TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
|
||||
image = self._env.physics.render(height=height, width=width, camera_id="top")
|
||||
return image
|
||||
|
||||
def _make_env_task(self, task_name):
|
||||
# time limit is controlled by StepCounter in env factory
|
||||
time_limit = float("inf")
|
||||
|
||||
if "sim_transfer_cube" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeTask(random=False)
|
||||
elif "sim_insertion" in task_name:
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionTask(random=False)
|
||||
elif "sim_end_effector_transfer_cube" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = TransferCubeEndEffectorTask(random=False)
|
||||
elif "sim_end_effector_insertion" in task_name:
|
||||
raise NotImplementedError()
|
||||
xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
|
||||
physics = mujoco.Physics.from_xml_path(str(xml_path))
|
||||
task = InsertionEndEffectorTask(random=False)
|
||||
else:
|
||||
raise NotImplementedError(task_name)
|
||||
|
||||
env = control.Environment(
|
||||
physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
|
||||
)
|
||||
return env
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["images"]["top"].copy())
|
||||
image = einops.rearrange(image, "h w c -> c h w")
|
||||
obs = {"image": image.type(torch.float32) / 255.0}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["qpos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO(rcadene):
|
||||
raise NotImplementedError()
|
||||
# obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif "sim_insertion" in self.task:
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.call_rendering_hooks()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
_, reward, discount, raw_obs = self._env.step(action[i])
|
||||
del discount # not used
|
||||
|
||||
# TOOD(rcadene): add an enum
|
||||
success = done = reward == 4
|
||||
sum_reward += reward
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
self.call_rendering_hooks()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
if self.from_pixels:
|
||||
if isinstance(self.image_size, int):
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
elif OmegaConf.is_list(self.image_size):
|
||||
assert len(self.image_size) == 3 # c h w
|
||||
assert self.image_size[0] == 3 # c is RGB
|
||||
image_shape = tuple(self.image_size)
|
||||
else:
|
||||
raise ValueError(self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = (len(JOINTS),)
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO: add low and high bounds
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
# TODO(rcadene): valid when controling end effector?
|
||||
# action_space = self._env.action_spec()
|
||||
# self.action_spec = BoundedTensorSpec(
|
||||
# low=action_space.minimum,
|
||||
# high=action_space.maximum,
|
||||
# shape=action_space.shape,
|
||||
# dtype=torch.float32,
|
||||
# device=self.device,
|
||||
# )
|
||||
|
||||
# TODO(rcaene): add bounds (where are they????)
|
||||
self.action_spec = BoundedTensorSpec(
|
||||
shape=(len(ACTIONS)),
|
||||
low=-1,
|
||||
high=1,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
# TODO(rcadene): seed the env
|
||||
# self._env.seed(seed)
|
||||
logging.warning("Aloha env is not seeded")
|
||||
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
219
lerobot/common/envs/aloha/tasks/sim.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7 : 7 + 6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7 + 6]
|
||||
|
||||
left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
|
||||
right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate(
|
||||
[left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
|
||||
)
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
263
lerobot/common/envs/aloha/tasks/sim_end_effector.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
from dm_control.suite import base
|
||||
|
||||
from lerobot.common.envs.aloha.constants import (
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
START_ARM_POSE,
|
||||
normalize_puppet_gripper_position,
|
||||
normalize_puppet_gripper_velocity,
|
||||
unnormalize_puppet_gripper_position,
|
||||
)
|
||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
|
||||
class BimanualViperXEndEffectorTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
|
||||
g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array(
|
||||
[
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
]
|
||||
)
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos(physics)
|
||||
obs["qvel"] = self.get_qvel(physics)
|
||||
obs["env_state"] = self.get_env_state(physics)
|
||||
obs["images"] = {}
|
||||
obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
|
||||
obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
|
||||
obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs["mocap_pose_left"] = np.concatenate(
|
||||
[physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
|
||||
).copy()
|
||||
obs["mocap_pose_right"] = np.concatenate(
|
||||
[physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
|
||||
).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs["gripper_ctrl"] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id("red_box_joint", "joint")
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
|
||||
def id2index(j_id):
|
||||
return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id("red_peg_joint", "joint")
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, "geom")
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, "geom")
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = (
|
||||
("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
)
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = (
|
||||
("socket-1", "table") in all_contact_pairs
|
||||
or ("socket-2", "table") in all_contact_pairs
|
||||
or ("socket-3", "table") in all_contact_pairs
|
||||
or ("socket-4", "table") in all_contact_pairs
|
||||
)
|
||||
peg_touch_socket = (
|
||||
("red_peg", "socket-1") in all_contact_pairs
|
||||
or ("red_peg", "socket-2") in all_contact_pairs
|
||||
or ("red_peg", "socket-3") in all_contact_pairs
|
||||
or ("red_peg", "socket-4") in all_contact_pairs
|
||||
)
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if (
|
||||
touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
|
||||
): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
39
lerobot/common/envs/aloha/utils.py
Normal file
39
lerobot/common/envs/aloha/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
@@ -1,43 +1,67 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
|
||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
"""
|
||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
||||
"""
|
||||
def make_env(cfg, transform=None):
|
||||
kwargs = {
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
"max_episode_steps": cfg.env.episode_length,
|
||||
"visualization_width": 384,
|
||||
"visualization_height": 384,
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
# TODO(rcadene): do we want a specific eval_env_seed?
|
||||
"seed": cfg.seed,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
raise e
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht import PushtEnv
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
if num_parallel_envs == 0:
|
||||
# non-batched version of the env that returns an observation of shape (c)
|
||||
env = gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
clsfunc = PushtEnv
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env = gym.vector.SyncVectorEnv(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
env = clsfunc(**kwargs)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
# env = GymEnv(
|
||||
# env_name,
|
||||
# frame_skip=frame_skip,
|
||||
# from_pixels=True,
|
||||
# pixels_only=False,
|
||||
# device=device,
|
||||
# )
|
||||
# env = TransformedEnv(env)
|
||||
# env.append_transform(NoopResetEnv(noops=30, random=True))
|
||||
# if not is_test:
|
||||
# env.append_transform(EndOfLifeTransform())
|
||||
# env.append_transform(RewardClipping(-1, 1))
|
||||
# env.append_transform(ToTensorImage())
|
||||
# env.append_transform(GrayScale())
|
||||
# env.append_transform(Resize(84, 84))
|
||||
# env.append_transform(CatFrames(N=4, dim=-3))
|
||||
# env.append_transform(RewardSum())
|
||||
# env.append_transform(StepCounter(max_steps=4500))
|
||||
# env.append_transform(DoubleToFloat())
|
||||
# env.append_transform(VecNorm(in_keys=["pixels"]))
|
||||
# return env
|
||||
|
||||
242
lerobot/common/envs/pusht.py
Normal file
242
lerobot/common/envs/pusht.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import importlib
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_diffpolicy = importlib.util.find_spec("diffusion_policy") is not None and _has_gym
|
||||
|
||||
|
||||
class PushtEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
num_prev_obs=0,
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
self.num_prev_obs = num_prev_obs
|
||||
self.num_prev_action = num_prev_action
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_diffpolicy:
|
||||
raise ImportError("Cannot import diffusion_policy.")
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from diffusion_policy.env.pusht.pusht_env import PushTEnv
|
||||
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
self._prev_obs_state_queue = deque(maxlen=self.num_prev_obs)
|
||||
if self.num_prev_action > 0:
|
||||
raise NotImplementedError()
|
||||
# self._prev_action_queue = deque(maxlen=self.num_prev_action)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
tmp = self._env.render_size
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode)
|
||||
self._env.render_size = tmp
|
||||
return out
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = torch.from_numpy(raw_obs["image"])
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.from_numpy(raw_obs["agent_pos"]).type(torch.float32)
|
||||
else:
|
||||
# TODO:
|
||||
obs = {"state": torch.from_numpy(raw_obs["observation"]).type(torch.float32)}
|
||||
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
|
||||
if action.ndim == 1:
|
||||
action = einops.repeat(action, "c -> t c", t=self.frame_skip)
|
||||
else:
|
||||
if self.frame_skip > 1:
|
||||
raise NotImplementedError()
|
||||
|
||||
num_action_steps = action.shape[0]
|
||||
for i in range(num_action_steps):
|
||||
raw_obs, reward, done, info = self._env.step(action[i])
|
||||
sum_reward += reward
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([done], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
image_shape = (3, self.image_size, self.image_size)
|
||||
if self.num_prev_obs > 0:
|
||||
image_shape = (self.num_prev_obs + 1, *image_shape)
|
||||
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=1,
|
||||
shape=image_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
state_shape = self._env.observation_space["agent_pos"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=512,
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
state_shape = self._env.observation_space["observation"].shape
|
||||
if self.num_prev_obs > 0:
|
||||
state_shape = (self.num_prev_obs + 1, *state_shape)
|
||||
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
# TODO:
|
||||
shape=state_shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._env.action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
self._env.seed(seed)
|
||||
181
lerobot/common/envs/simxarm.py
Normal file
181
lerobot/common/envs/simxarm.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
BoundedTensorSpec,
|
||||
CompositeSpec,
|
||||
DiscreteTensorSpec,
|
||||
UnboundedContinuousTensorSpec,
|
||||
)
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
|
||||
_has_gym = importlib.util.find_spec("gym") is not None
|
||||
_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym
|
||||
|
||||
|
||||
class SimxarmEnv(EnvBase):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
frame_skip: int = 1,
|
||||
from_pixels: bool = False,
|
||||
pixels_only: bool = False,
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
self.pixels_only = pixels_only
|
||||
self.image_size = image_size
|
||||
|
||||
if pixels_only:
|
||||
assert from_pixels
|
||||
if from_pixels:
|
||||
assert image_size
|
||||
|
||||
if not _has_simxarm:
|
||||
raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
import gym
|
||||
from simxarm import TASKS
|
||||
|
||||
if self.task not in TASKS:
|
||||
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
return self._env.render(mode, width=width, height=height)
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||
else:
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
obs = TensorDict(obs, batch_size=[])
|
||||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
raw_obs = self._env.reset()
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"done": torch.tensor([False], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
# step expects shape=(4,) so we pad if necessary
|
||||
action = np.concatenate([action, self._action_padding])
|
||||
# TODO(rcadene): add info["is_success"] and info["success"] ?
|
||||
sum_reward = 0
|
||||
for _ in range(self.frame_skip):
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
sum_reward += reward
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
"observation": self._format_raw_obs(raw_obs),
|
||||
"reward": torch.tensor([sum_reward], dtype=torch.float32),
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([info["success"]], dtype=torch.bool),
|
||||
},
|
||||
batch_size=[],
|
||||
)
|
||||
return td
|
||||
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# TODO(rcadene): add observation_space achieved_goal and desired_goal?
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=self._env.observation_space["observation"].shape,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.observation_spec = CompositeSpec({"observation": obs})
|
||||
|
||||
self.action_spec = _gym_to_torchrl_spec_transform(
|
||||
self._action_space,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.reward_spec = UnboundedContinuousTensorSpec(
|
||||
shape=(1,),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.done_spec = CompositeSpec(
|
||||
{
|
||||
"done": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
"success": DiscreteTensorSpec(
|
||||
2,
|
||||
shape=(1,),
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_seed(seed)
|
||||
self._env.seed(seed)
|
||||
92
lerobot/common/envs/transforms.py
Normal file
92
lerobot/common/envs/transforms.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Sequence
|
||||
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
|
||||
def _call(self, td):
|
||||
for key in self.in_keys:
|
||||
td[key] *= self.prod
|
||||
return td
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
obs_spec[key].space.high *= self.prod
|
||||
return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: TensorDictBase,
|
||||
in_keys: Sequence[NestedKey] = None,
|
||||
out_keys: Sequence[NestedKey] | None = None,
|
||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
if in_keys_inv is None:
|
||||
in_keys_inv = out_keys
|
||||
if out_keys_inv is None:
|
||||
out_keys_inv = in_keys
|
||||
super().__init__(
|
||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||
)
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
td[outkey] = (td[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
td[outkey] = td[outkey] * 2 - 1
|
||||
return td
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = td[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
td[outkey] = (td[inkey] + 1) / 2
|
||||
td[outkey] = td[outkey] * (max - min) + min
|
||||
return td
|
||||
@@ -1,42 +0,0 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
|
||||
def preprocess_observation(observation):
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
|
||||
if isinstance(observation["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observation["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
obs[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
return action
|
||||
@@ -30,7 +30,6 @@ class Logger:
|
||||
self._model_dir = self._log_dir / "models"
|
||||
self._buffer_dir = self._log_dir / "buffers"
|
||||
self._save_model = cfg.save_model
|
||||
self._disable_wandb_artifact = cfg.wandb.disable_artifact
|
||||
self._save_buffer = cfg.save_buffer
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
@@ -39,7 +38,7 @@ class Logger:
|
||||
project = cfg.get("wandb", {}).get("project")
|
||||
entity = cfg.get("wandb", {}).get("entity")
|
||||
enable_wandb = cfg.get("wandb", {}).get("enable", False)
|
||||
run_offline = not enable_wandb or not project
|
||||
run_offline = not enable_wandb or not project or not entity
|
||||
if run_offline:
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
self._wandb = None
|
||||
@@ -64,7 +63,6 @@ class Logger:
|
||||
resume=None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
@@ -72,10 +70,9 @@ class Logger:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
if self._wandb:
|
||||
artifact = self._wandb.Artifact(
|
||||
self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier),
|
||||
self._group + "-" + str(self._seed) + "-" + str(identifier),
|
||||
type="model",
|
||||
)
|
||||
artifact.add_file(fp)
|
||||
|
||||
0
lerobot/common/policies/__init__.py
Normal file
0
lerobot/common/policies/__init__.py
Normal file
115
lerobot/common/policies/act/backbone.py
Normal file
115
lerobot/common/policies/act/backbone.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
|
||||
from .position_encoding import build_position_encoding
|
||||
from .utils import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super().__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
def __init__(
|
||||
self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
pretrained=is_main_process(),
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for _, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(args):
|
||||
position_embedding = build_position_encoding(args)
|
||||
train_backbone = args.lr_backbone > 0
|
||||
return_interm_layers = args.masks
|
||||
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
@@ -1,147 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionChunkingTransformerConfig:
|
||||
"""Configuration class for the Action Chunking Transformers policy.
|
||||
|
||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
chunk_size: The size of the action prediction "chunks" in units of environment steps.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two availables
|
||||
modes are "mean_std" which substracts the mean and divide by the standard
|
||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||
convolution.
|
||||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
||||
d_model: The transformer blocks' main hidden dimension.
|
||||
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
|
||||
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
|
||||
layers.
|
||||
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
|
||||
n_encoder_layers: The number of transformer layers to use for the transformer encoder.
|
||||
n_decoder_layers: The number of transformer layers to use for the transformer decoder.
|
||||
use_vae: Whether to use a variational objective during training. This introduces another transformer
|
||||
which is used as the VAE's encoder (not to be confused with the transformer encoder - see
|
||||
documentation in the policy class).
|
||||
latent_dim: The VAE's latent dimension.
|
||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||
use_temporal_aggregation: Whether to blend the actions of multiple policy invocations for any given
|
||||
environment step.
|
||||
dropout: Dropout to use in the transformer layers (see code for details).
|
||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
replace_final_stride_with_dilation: int = False
|
||||
# Transformer layers.
|
||||
pre_norm: bool = False
|
||||
d_model: int = 512
|
||||
n_heads: int = 8
|
||||
dim_feedforward: int = 3200
|
||||
feedforward_activation: str = "relu"
|
||||
n_encoder_layers: int = 4
|
||||
n_decoder_layers: int = 1
|
||||
# VAE.
|
||||
use_vae: bool = True
|
||||
latent_dim: int = 32
|
||||
n_vae_encoder_layers: int = 4
|
||||
|
||||
# Inference.
|
||||
use_temporal_aggregation: bool = False
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: float = 0.1
|
||||
kl_weight: float = 10.0
|
||||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: int = 8
|
||||
lr: float = 1e-5
|
||||
lr_backbone: float = 1e-5
|
||||
weight_decay: float = 1e-4
|
||||
grad_clip_norm: float = 10
|
||||
utd: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.use_temporal_aggregation:
|
||||
raise NotImplementedError("Temporal aggregation is not yet implemented.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.n_obs_steps != 1:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
# Check that there is only one image.
|
||||
# TODO(alexander-soare): generalize this to multiple images.
|
||||
if (
|
||||
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
|
||||
or "observation.images.top" not in self.input_shapes
|
||||
):
|
||||
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')
|
||||
212
lerobot/common/policies/act/detr_vae.py
Normal file
212
lerobot/common/policies/act/detr_vae.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .transformer import TransformerEncoder, TransformerEncoderLayer, build_transformer
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names, vae
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vae = vae
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
# TODO(rcadene): understand what is env_state, and why it needs to be 7
|
||||
self.input_proj_env_state = nn.Linear(state_dim // 2, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None, is_pad=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
### Obtain latent z from action sequence
|
||||
if self.vae and is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
# cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
|
||||
# is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(encoder_input, pos=pos_embed) # , src_key_padding_mask=is_pad)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
else:
|
||||
mu = logvar = None
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, _ in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[0](image[:, cam_id]) # HARDCODED
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar]
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for _ in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
|
||||
|
||||
|
||||
def build_encoder(args):
|
||||
d_model = args.hidden_dim # 256
|
||||
dropout = args.dropout # 0.1
|
||||
nhead = args.nheads # 8
|
||||
dim_feedforward = args.dim_feedforward # 2048
|
||||
num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
|
||||
normalize_before = args.pre_norm # False
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build(args):
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(args)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(args)
|
||||
|
||||
encoder = build_encoder(args)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim=args.state_dim,
|
||||
action_dim=args.action_dim,
|
||||
num_queries=args.num_queries,
|
||||
camera_names=args.camera_names,
|
||||
vae=args.vae,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: {:.2f}M".format(n_parameters / 1e6))
|
||||
|
||||
return model
|
||||
@@ -1,610 +0,0 @@
|
||||
"""Action Chunking Transformer Policy
|
||||
|
||||
As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705).
|
||||
The majority of changes here involve removing unused code, unifying naming, and adding helpful comments.
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
|
||||
Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.
|
||||
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
||||
model that encodes the target data (a sequence of actions), and the condition (the robot
|
||||
joint-space).
|
||||
- A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with
|
||||
cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we
|
||||
have an option to train this model without the variational objective (in which case we drop the
|
||||
`vae_encoder` altogether, and nothing about this model has anything to do with a VAE).
|
||||
|
||||
Transformer
|
||||
Used alone for inference
|
||||
(acts as VAE decoder
|
||||
during training)
|
||||
┌───────────────────────┐
|
||||
│ Outputs │
|
||||
│ ▲ │
|
||||
│ ┌─────►┌───────┐ │
|
||||
┌──────┐ │ │ │Transf.│ │
|
||||
│ │ │ ├─────►│decoder│ │
|
||||
┌────┴────┐ │ │ │ │ │ │
|
||||
│ │ │ │ ┌───┴───┬─►│ │ │
|
||||
│ VAE │ │ │ │ │ └───────┘ │
|
||||
│ encoder │ │ │ │Transf.│ │
|
||||
│ │ │ │ │encoder│ │
|
||||
└───▲─────┘ │ │ │ │ │
|
||||
│ │ │ └───▲───┘ │
|
||||
│ │ │ │ │
|
||||
inputs └─────┼─────┘ │
|
||||
│ │
|
||||
└───────────────────────┘
|
||||
"""
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
if self.cfg.use_vae:
|
||||
self.vae_encoder = _TransformerEncoder(cfg)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
cfg.input_shapes["observation.state"][0], cfg.d_model
|
||||
)
|
||||
self.latent_dim = cfg.latent_dim
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(cfg.d_model, self.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
weights=cfg.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
||||
# map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = _TransformerEncoder(cfg)
|
||||
self.decoder = _TransformerDecoder(cfg)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
self.encoder_robot_state_input_proj = nn.Linear(cfg.input_shapes["observation.state"][0], cfg.d_model)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, cfg.d_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, cfg.d_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
|
||||
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
self.decoder_pos_embed = nn.Embedding(cfg.chunk_size, cfg.d_model)
|
||||
|
||||
# Final action regression head on the output of the transformer's decoder.
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
self._create_optimizer()
|
||||
|
||||
def _create_optimizer(self):
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
||||
)
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||
for p in chain(self.encoder.parameters(), self.decoder.parameters()):
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.cfg.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.cfg.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss}
|
||||
if self.cfg.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.cfg.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
loss_dict = self.forward(batch)
|
||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images.{name}": (B, C, H, W) tensor of images.
|
||||
}
|
||||
"""
|
||||
# Stack images in the order dictated by input_shapes.
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[k] for k in self.cfg.input_shapes if k.startswith("observation.images.")],
|
||||
dim=-4,
|
||||
)
|
||||
|
||||
def _forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
|
||||
{
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
Returns:
|
||||
(B, chunk_size, action_dim) batch of action sequences
|
||||
Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the
|
||||
latent dimension.
|
||||
"""
|
||||
if self.cfg.use_vae and self.training:
|
||||
assert (
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
self._stack_images(batch)
|
||||
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.cfg.use_vae and "action" in batch:
|
||||
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
|
||||
1
|
||||
) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||
|
||||
# Prepare fixed positional embedding.
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
|
||||
# Forward pass through VAE encoder to get the latent PDF parameters.
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||
)[0] # select the class token, with shape (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
mu = latent_pdf_params[:, : self.latent_dim]
|
||||
# This is 2log(sigma). Done this way to match the original implementation.
|
||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||
|
||||
# Sample the latent with the reparameterization trick.
|
||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||
else:
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
# Prepare all other transformer encoder inputs.
|
||||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
||||
encoder_in = torch.cat(all_cam_features, axis=3)
|
||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=3)
|
||||
|
||||
# Get positional embeddings for robot state and latent.
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"])
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample)
|
||||
|
||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||
encoder_in = torch.cat(
|
||||
[
|
||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||
encoder_in.flatten(2).permute(2, 0, 1),
|
||||
]
|
||||
)
|
||||
pos_embed = torch.cat(
|
||||
[
|
||||
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
|
||||
cam_pos_embed.flatten(2).permute(2, 0, 1),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
decoder_in = torch.zeros(
|
||||
(self.cfg.chunk_size, batch_size, self.cfg.d_model),
|
||||
dtype=pos_embed.dtype,
|
||||
device=pos_embed.device,
|
||||
)
|
||||
decoder_out = self.decoder(
|
||||
decoder_in,
|
||||
encoder_out,
|
||||
encoder_pos_embed=pos_embed,
|
||||
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
|
||||
)
|
||||
|
||||
# Move back to (B, S, C).
|
||||
decoder_out = decoder_out.transpose(0, 1)
|
||||
|
||||
actions = self.action_head(decoder_out)
|
||||
|
||||
return actions, (mu, log_sigma_x2)
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
class _TransformerEncoder(nn.Module):
|
||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
|
||||
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class _TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||
|
||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
||||
self.pre_norm = cfg.pre_norm
|
||||
|
||||
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = x if pos_embed is None else x + pos_embed
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
x = self.norm2(x)
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
skip = x
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = skip + self.dropout2(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm2(x)
|
||||
return x
|
||||
|
||||
|
||||
class _TransformerDecoder(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(cfg.d_model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
encoder_out: Tensor,
|
||||
decoder_pos_embed: Tensor | None = None,
|
||||
encoder_pos_embed: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class _TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||
|
||||
# Feed forward layers.
|
||||
self.linear1 = nn.Linear(cfg.d_model, cfg.dim_feedforward)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
self.linear2 = nn.Linear(cfg.dim_feedforward, cfg.d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm2 = nn.LayerNorm(cfg.d_model)
|
||||
self.norm3 = nn.LayerNorm(cfg.d_model)
|
||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||
self.dropout3 = nn.Dropout(cfg.dropout)
|
||||
|
||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
||||
self.pre_norm = cfg.pre_norm
|
||||
|
||||
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
||||
return tensor if pos_embed is None else tensor + pos_embed
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
encoder_out: Tensor,
|
||||
decoder_pos_embed: Tensor | None = None,
|
||||
encoder_pos_embed: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
|
||||
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
|
||||
cross-attending with.
|
||||
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
|
||||
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
|
||||
Returns:
|
||||
(DS, B, C) tensor of decoder output features.
|
||||
"""
|
||||
skip = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
|
||||
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
x = self.norm2(x)
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
skip = x
|
||||
x = self.multihead_attn(
|
||||
query=self.maybe_add_pos_embed(x, decoder_pos_embed),
|
||||
key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
|
||||
value=encoder_out,
|
||||
)[0] # select just the output, not the attention weights
|
||||
x = skip + self.dropout2(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
x = self.norm3(x)
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
skip = x
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = skip + self.dropout3(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm3(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
||||
|
||||
Args:
|
||||
num_positions: Number of token positions required.
|
||||
Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).
|
||||
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
return torch.from_numpy(sinusoid_table).float()
|
||||
|
||||
|
||||
class _SinusoidalPositionEmbedding2D(nn.Module):
|
||||
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
||||
|
||||
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
||||
for the vertical direction, and 1/W for the horizontal direction.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
"""
|
||||
Args:
|
||||
dimension: The desired dimension of the embeddings.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dimension = dimension
|
||||
self._two_pi = 2 * math.pi
|
||||
self._eps = 1e-6
|
||||
# Inverse "common ratio" for the geometric progression in sinusoid frequencies.
|
||||
self._temperature = 10000
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.
|
||||
Returns:
|
||||
A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.
|
||||
"""
|
||||
not_mask = torch.ones_like(x[0, :1]) # (1, H, W)
|
||||
# Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations
|
||||
# they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.
|
||||
y_range = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_range = not_mask.cumsum(2, dtype=torch.float32)
|
||||
|
||||
# "Normalize" the position index such that it ranges in [0, 2π].
|
||||
# Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range
|
||||
# are non-zero by construction. This is an artifact of the original code.
|
||||
y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi
|
||||
x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi
|
||||
|
||||
inverse_frequency = self._temperature ** (
|
||||
2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension
|
||||
)
|
||||
|
||||
x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)
|
||||
|
||||
# Note: this stack then flatten operation results in interleaved sine and cosine terms.
|
||||
# pos_embed_x and pos_embed_y are (1, H, W, C // 2).
|
||||
pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3)
|
||||
pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_activation_fn(activation: str) -> Callable:
|
||||
"""Return an activation function given a string."""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
||||
218
lerobot/common/policies/act/policy.py
Normal file
218
lerobot/common/policies/act/policy.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
|
||||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
model = build(cfg)
|
||||
|
||||
param_dicts = [
|
||||
{"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
|
||||
"lr": cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
optimizer = torch.optim.AdamW(param_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
|
||||
|
||||
return model, optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
del step
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
self.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon)
|
||||
|
||||
image = batch["observation", "image", "top"]
|
||||
image = image[:, 0] # first observation t=0
|
||||
# batch, num_cam, channel, height, width
|
||||
image = image.unsqueeze(1)
|
||||
assert image.ndim == 5
|
||||
image = image.float()
|
||||
|
||||
state = batch["observation", "state"]
|
||||
state = state[:, 0] # first observation t=0
|
||||
# batch, qpos_dim
|
||||
assert state.ndim == 2
|
||||
|
||||
action = batch["action"]
|
||||
# batch, seq, action_dim
|
||||
assert action.ndim == 3
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if self.cfg.n_obs_steps > 1:
|
||||
raise NotImplementedError()
|
||||
# # keep first n observations of the slice corresponding to t=[-1,0]
|
||||
# image = image[:, : self.cfg.n_obs_steps]
|
||||
# state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
# self.lr_scheduler.step()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
# "lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"lr": self.cfg.lr,
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
loss_dict = self._forward(
|
||||
qpos=batch["obs"]["agent_pos"],
|
||||
image=batch["obs"]["image"],
|
||||
actions=batch["action"],
|
||||
)
|
||||
loss = loss_dict["loss"]
|
||||
return loss
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
self.eval()
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image"] = observation["image"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"], image=obs_dict["image"])
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
raise NotImplementedError()
|
||||
# all_time_actions[[t], t:t+num_queries] = action
|
||||
# actions_for_curr_step = all_time_actions[:, t]
|
||||
# actions_populated = torch.all(actions_for_curr_step != 0, axis=1)
|
||||
# actions_for_curr_step = actions_for_curr_step[actions_populated]
|
||||
# k = 0.01
|
||||
# exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
|
||||
# exp_weights = exp_weights / exp_weights.sum()
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# remove bsize=1
|
||||
action = action.squeeze(0)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
|
||||
return action
|
||||
|
||||
def _forward(self, qpos, image, actions=None, is_pad=None):
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
image = normalize(image)
|
||||
|
||||
is_training = actions is not None
|
||||
if is_training: # training time
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l1"] = l1
|
||||
if self.cfg.vae:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
||||
101
lerobot/common/policies/act/position_encoding.py
Normal file
101
lerobot/common/policies/act/position_encoding.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .utils import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = (
|
||||
torch.cat(
|
||||
[
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(x.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(args):
|
||||
n_steps = args.hidden_dim // 2
|
||||
if args.position_embedding in ("v2", "sine"):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(n_steps, normalize=True)
|
||||
elif args.position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(n_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {args.position_embedding}")
|
||||
|
||||
return position_embedding
|
||||
370
lerobot/common/policies/act/transformer.py
Normal file
370
lerobot/common/policies/act/transformer.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer, num_decoder_layers, decoder_norm, return_intermediate=return_intermediate_dec
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def build_transformer(args):
|
||||
return Transformer(
|
||||
d_model=args.hidden_dim,
|
||||
dropout=args.dropout,
|
||||
nhead=args.nheads,
|
||||
dim_feedforward=args.dim_feedforward,
|
||||
num_encoder_layers=args.enc_layers,
|
||||
num_decoder_layers=args.dec_layers,
|
||||
normalize_before=args.pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
477
lerobot/common/policies/act/utils.py
Normal file
477
lerobot/common/policies/act/utils.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
from packaging import version
|
||||
from torch import Tensor
|
||||
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list, strict=False):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values, strict=False)} # noqa: C416
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append("{}: {}".format(name, str(meter)))
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
]
|
||||
)
|
||||
mega_b = 1024.0 * 1024.0
|
||||
for i, obj in enumerate(iterable):
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / mega_b,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
||||
|
||||
sha = "N/A"
|
||||
diff = "clean"
|
||||
branch = "N/A"
|
||||
try:
|
||||
sha = _run(["git", "rev-parse", "HEAD"])
|
||||
subprocess.check_output(["git", "diff"], cwd=cwd)
|
||||
diff = _run(["git", "diff-index", "HEAD"])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch, strict=False))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], : img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError("not supported")
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(
|
||||
torch.int64
|
||||
)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape), strict=False)]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
elif "SLURM_PROCID" in os.environ:
|
||||
args.rank = int(os.environ["SLURM_PROCID"])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse("0.7"):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
@@ -1,157 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
"""Configuration class for Diffusion Policy.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two availables
|
||||
modes are "mean_std" which substracts the mean and divide by the standard
|
||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
kernel_size: The convolutional kernel size of the diffusion modeling Unet.
|
||||
n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
|
||||
diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
|
||||
network. This is the output dimension of that network, i.e., the embedding dimension.
|
||||
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
|
||||
Bias modulation is used be default, while this parameter indicates whether to also use scale
|
||||
modulation.
|
||||
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
|
||||
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
|
||||
beta_start: Beta value for the first forward-diffusion step.
|
||||
beta_end: Beta value for the last forward-diffusion step.
|
||||
prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
|
||||
or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
|
||||
"epsilon" has been shown to work better in many deep neural network settings.
|
||||
clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
|
||||
denoising step at inference time. WARNING: you will need to make sure your action-space is
|
||||
normalized to fit within this range.
|
||||
clip_sample_range: The magnitude of the clipping range as described above.
|
||||
num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
|
||||
spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
|
||||
"""
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
n_groups: int = 8
|
||||
diffusion_step_embed_dim: int = 128
|
||||
use_film_scale_modulation: bool = True
|
||||
# Noise scheduler.
|
||||
num_train_timesteps: int = 100
|
||||
beta_schedule: str = "squaredcos_cap_v2"
|
||||
beta_start: float = 0.0001
|
||||
beta_end: float = 0.02
|
||||
prediction_type: str = "epsilon"
|
||||
clip_sample: bool = True
|
||||
clip_sample_range: float = 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# ---
|
||||
# TODO(alexander-soare): Remove these from the policy config.
|
||||
batch_size: int = 64
|
||||
grad_clip_norm: int = 10
|
||||
lr: float = 1.0e-4
|
||||
lr_scheduler: str = "cosine"
|
||||
lr_warmup_steps: int = 500
|
||||
adam_betas: tuple[float, float] = (0.95, 0.999)
|
||||
adam_eps: float = 1.0e-8
|
||||
adam_weight_decay: float = 1.0e-6
|
||||
utd: int = 1
|
||||
use_ema: bool = True
|
||||
ema_update_after_step: int = 0
|
||||
ema_min_alpha: float = 0.0
|
||||
ema_max_alpha: float = 0.9999
|
||||
ema_inv_gamma: float = 1.0
|
||||
ema_power: float = 0.75
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes["observation.image"][1]
|
||||
or self.crop_shape[1] > self.input_shapes["observation.image"][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f'`crop_shape` should fit within `input_shapes["observation.image"]`. Got {self.crop_shape} '
|
||||
f'for `crop_shape` and {self.input_shapes["observation.image"]} for '
|
||||
'`input_shapes["observation.image"]`.'
|
||||
)
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
|
||||
)
|
||||
246
lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
Normal file
246
lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from einops import reduce
|
||||
|
||||
from diffusion_policy.common.pytorch_util import dict_apply
|
||||
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
|
||||
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
|
||||
|
||||
|
||||
class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
noise_scheduler: DDPMScheduler,
|
||||
obs_encoder: MultiImageObsEncoder,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
obs_as_global_cond=True,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=(256, 512, 1024),
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
cond_predict_scale=True,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# parse shapes
|
||||
action_shape = shape_meta["action"]["shape"]
|
||||
assert len(action_shape) == 1
|
||||
action_dim = action_shape[0]
|
||||
# get feature dim
|
||||
obs_feature_dim = obs_encoder.output_shape()[0]
|
||||
|
||||
# create diffusion model
|
||||
input_dim = action_dim + obs_feature_dim
|
||||
global_cond_dim = None
|
||||
if obs_as_global_cond:
|
||||
input_dim = action_dim
|
||||
global_cond_dim = obs_feature_dim * n_obs_steps
|
||||
|
||||
model = ConditionalUnet1D(
|
||||
input_dim=input_dim,
|
||||
local_cond_dim=None,
|
||||
global_cond_dim=global_cond_dim,
|
||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||
down_dims=down_dims,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
)
|
||||
|
||||
self.obs_encoder = obs_encoder
|
||||
self.model = model
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.mask_generator = LowdimMaskGenerator(
|
||||
action_dim=action_dim,
|
||||
obs_dim=0 if obs_as_global_cond else obs_feature_dim,
|
||||
max_n_obs_steps=n_obs_steps,
|
||||
fix_obs_steps=True,
|
||||
action_visible=False,
|
||||
)
|
||||
self.horizon = horizon
|
||||
self.obs_feature_dim = obs_feature_dim
|
||||
self.action_dim = action_dim
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.obs_as_global_cond = obs_as_global_cond
|
||||
self.kwargs = kwargs
|
||||
|
||||
if num_inference_steps is None:
|
||||
num_inference_steps = noise_scheduler.config.num_train_timesteps
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
condition_data,
|
||||
condition_mask,
|
||||
local_cond=None,
|
||||
global_cond=None,
|
||||
generator=None,
|
||||
# keyword arguments to scheduler.step
|
||||
**kwargs,
|
||||
):
|
||||
model = self.model
|
||||
scheduler = self.noise_scheduler
|
||||
|
||||
trajectory = torch.randn(
|
||||
size=condition_data.shape,
|
||||
dtype=condition_data.dtype,
|
||||
device=condition_data.device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# set step values
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
for t in scheduler.timesteps:
|
||||
# 1. apply conditioning
|
||||
trajectory[condition_mask] = condition_data[condition_mask]
|
||||
|
||||
# 2. predict model output
|
||||
model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
|
||||
|
||||
# 3. compute previous image: x_t -> x_t-1
|
||||
trajectory = scheduler.step(
|
||||
model_output,
|
||||
t,
|
||||
trajectory,
|
||||
generator=generator,
|
||||
# **kwargs # TODO(rcadene): in diffusion_policy, expected to be {}
|
||||
).prev_sample
|
||||
|
||||
# finally make sure conditioning is enforced
|
||||
trajectory[condition_mask] = condition_data[condition_mask]
|
||||
|
||||
return trajectory
|
||||
|
||||
def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
obs_dict: must include "obs" key
|
||||
result: must include "action" key
|
||||
"""
|
||||
assert "past_action" not in obs_dict # not implemented yet
|
||||
nobs = obs_dict
|
||||
value = next(iter(nobs.values()))
|
||||
bsize, n_obs_steps = value.shape[:2]
|
||||
horizon = self.horizon
|
||||
action_dim = self.action_dim
|
||||
obs_dim = self.obs_feature_dim
|
||||
assert self.n_obs_steps == n_obs_steps
|
||||
|
||||
# build input
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
|
||||
# handle different ways of passing observation
|
||||
local_cond = None
|
||||
global_cond = None
|
||||
if self.obs_as_global_cond:
|
||||
# condition through global feature
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, Do
|
||||
global_cond = nobs_features.reshape(bsize, -1)
|
||||
# empty data for action
|
||||
cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
else:
|
||||
# condition through impainting
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
|
||||
cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
|
||||
cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
|
||||
cond_data[:, :n_obs_steps, action_dim:] = nobs_features
|
||||
cond_mask[:, :n_obs_steps, action_dim:] = True
|
||||
|
||||
# run sampling
|
||||
nsample = self.conditional_sample(
|
||||
cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
|
||||
)
|
||||
|
||||
action_pred = nsample[..., :action_dim]
|
||||
|
||||
# get action
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.n_action_steps
|
||||
action = action_pred[:, start:end]
|
||||
|
||||
result = {"action": action, "action_pred": action_pred}
|
||||
return result
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert "valid_mask" not in batch
|
||||
nobs = batch["obs"]
|
||||
nactions = batch["action"]
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
|
||||
# handle different ways of passing observation
|
||||
local_cond = None
|
||||
global_cond = None
|
||||
trajectory = nactions
|
||||
cond_data = trajectory
|
||||
if self.obs_as_global_cond:
|
||||
# reshape B, T, ... to B*T
|
||||
this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, Do
|
||||
global_cond = nobs_features.reshape(batch_size, -1)
|
||||
else:
|
||||
# reshape B, T, ... to B*T
|
||||
this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
|
||||
nobs_features = self.obs_encoder(this_nobs)
|
||||
# reshape back to B, T, Do
|
||||
nobs_features = nobs_features.reshape(batch_size, horizon, -1)
|
||||
cond_data = torch.cat([nactions, nobs_features], dim=-1)
|
||||
trajectory = cond_data.detach()
|
||||
|
||||
# generate impainting mask
|
||||
condition_mask = self.mask_generator(trajectory.shape)
|
||||
|
||||
# Sample noise that we'll add to the images
|
||||
noise = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
bsz = trajectory.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
|
||||
).long()
|
||||
# Add noise to the clean images according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
|
||||
|
||||
# compute loss mask
|
||||
loss_mask = ~condition_mask
|
||||
|
||||
# apply conditioning
|
||||
noisy_trajectory[condition_mask] = cond_data[condition_mask]
|
||||
|
||||
# Predict the noise residual
|
||||
pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
|
||||
|
||||
pred_type = self.noise_scheduler.config.prediction_type
|
||||
if pred_type == "epsilon":
|
||||
target = noise
|
||||
elif pred_type == "sample":
|
||||
target = trajectory
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {pred_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
loss = loss * loss_mask.type(loss.dtype)
|
||||
loss = reduce(loss, "b ... -> b (...)", "mean")
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
@@ -1,737 +0,0 @@
|
||||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
|
||||
TODO(alexander-soare):
|
||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||
- Move EMA out of policy.
|
||||
- Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
|
||||
- One more pass on comments and documentation.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
"""
|
||||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
configuration class is used.
|
||||
"""
|
||||
super().__init__()
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
assert lr_scheduler_num_training_steps > 0
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
self.diffusion = _DiffusionUnetImagePolicy(cfg)
|
||||
|
||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||
self.ema_diffusion = None
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||
|
||||
# TODO(alexander-soare): Move optimizer out of policy.
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||
)
|
||||
|
||||
# TODO(alexander-soare): Move LR scheduler out of policy.
|
||||
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||
self.global_step = 0
|
||||
|
||||
# configure lr scheduler
|
||||
self.lr_scheduler = get_scheduler(
|
||||
cfg.lr_scheduler,
|
||||
optimizer=self.optimizer,
|
||||
num_warmup_steps=cfg.lr_warmup_steps,
|
||||
num_training_steps=lr_scheduler_num_training_steps,
|
||||
# pytorch assumes stepping LRScheduler every epoch
|
||||
# however huggingface diffusers steps it every batch
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
|
||||
"action": deque(maxlen=self.cfg.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
underlying diffusion model. Here's how it works:
|
||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||
copied `n_obs_steps` times to fill the cache).
|
||||
- The diffusion model generates `horizon` steps worth of actions.
|
||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
||||
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
|
||||
Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model
|
||||
weights.
|
||||
"""
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
if not self.training and self.ema_diffusion is not None:
|
||||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
def update(self, batch: dict[str, Tensor], **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.ema is not None:
|
||||
self.ema.step(self.diffusion)
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||
if len(missing_keys) > 0:
|
||||
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||
logging.warning(
|
||||
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
|
||||
)
|
||||
assert len(unexpected_keys) == 0
|
||||
|
||||
|
||||
class _DiffusionUnetImagePolicy(nn.Module):
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.rgb_encoder = _RgbEncoder(cfg)
|
||||
self.unet = _ConditionalUnet1D(
|
||||
cfg,
|
||||
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=cfg.num_train_timesteps,
|
||||
beta_start=cfg.beta_start,
|
||||
beta_end=cfg.beta_end,
|
||||
beta_schedule=cfg.beta_schedule,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=cfg.clip_sample,
|
||||
clip_sample_range=cfg.clip_sample_range,
|
||||
prediction_type=cfg.prediction_type,
|
||||
)
|
||||
|
||||
if cfg.num_inference_steps is None:
|
||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||
else:
|
||||
self.num_inference_steps = cfg.num_inference_steps
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
|
||||
# Sample prior.
|
||||
sample = torch.randn(
|
||||
size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
for t in self.noise_scheduler.timesteps:
|
||||
# Predict model output.
|
||||
model_output = self.unet(
|
||||
sample,
|
||||
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
|
||||
global_cond=global_cond,
|
||||
)
|
||||
# Compute previous image: x_t -> x_t-1
|
||||
sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample
|
||||
|
||||
return sample
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
}
|
||||
"""
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.cfg.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
# run sampling
|
||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
actions = sample[..., : self.cfg.output_shapes["action"][0]]
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.cfg.n_action_steps
|
||||
actions = actions[:, start:end]
|
||||
|
||||
return actions
|
||||
|
||||
def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.cfg.horizon
|
||||
assert n_obs_steps == self.cfg.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
trajectory = batch["action"]
|
||||
|
||||
# Forward diffusion.
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
size=(trajectory.shape[0],),
|
||||
device=trajectory.device,
|
||||
).long()
|
||||
# Add noise to the clean trajectories according to the noise magnitude at each timestep.
|
||||
noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps)
|
||||
|
||||
# Run the denoising network (that might denoise the trajectory, or attempt to predict the noise).
|
||||
pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
|
||||
|
||||
# Compute the loss.
|
||||
# The target is either the original trajectory, or the noise.
|
||||
if self.cfg.prediction_type == "epsilon":
|
||||
target = eps
|
||||
elif self.cfg.prediction_type == "sample":
|
||||
target = batch["action"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(pred, target, reduction="none")
|
||||
|
||||
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
|
||||
if "action_is_pad" in batch:
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class _RgbEncoder(nn.Module):
|
||||
"""Encoder an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if cfg.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
|
||||
if cfg.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
# Set up backbone.
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
weights=cfg.pretrained_backbone_weights
|
||||
)
|
||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||
# TODO(alexander-soare): Use a safer alternative.
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
if cfg.use_group_norm:
|
||||
if cfg.pretrained_backbone_weights:
|
||||
raise ValueError(
|
||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||
)
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(
|
||||
self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
|
||||
)
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
else:
|
||||
# Always use center crop for eval.
|
||||
x = self.center_crop(x)
|
||||
# Extract backbone feature.
|
||||
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||
# Final linear layer with non-linearity.
|
||||
x = self.relu(self.out(x))
|
||||
return x
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
root_module: The module for which the submodules need to be replaced
|
||||
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
|
||||
func: Takes a module as an argument and returns a new module to replace it with.
|
||||
Returns:
|
||||
The root module with its submodules replaced.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parents))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
class _SinusoidalPosEmb(nn.Module):
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class _Conv1dBlock(nn.Module):
|
||||
"""Conv1d --> GroupNorm --> Mish"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class _ConditionalUnet1D(nn.Module):
|
||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||
|
||||
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = cfg
|
||||
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
||||
)
|
||||
|
||||
# The FiLM conditioning dimension.
|
||||
cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
|
||||
|
||||
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
|
||||
# just reverse these.
|
||||
in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
|
||||
zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
|
||||
)
|
||||
|
||||
# Unet encoder.
|
||||
common_res_block_kwargs = {
|
||||
"cond_dim": cond_dim,
|
||||
"kernel_size": cfg.kernel_size,
|
||||
"n_groups": cfg.n_groups,
|
||||
"use_film_scale_modulation": cfg.use_film_scale_modulation,
|
||||
}
|
||||
self.down_modules = nn.ModuleList([])
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Processing in the middle of the auto-encoder.
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
]
|
||||
)
|
||||
|
||||
# Unet decoder.
|
||||
self.up_modules = nn.ModuleList([])
|
||||
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (len(in_out) - 1)
|
||||
self.up_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, T, input_dim) tensor for input to the Unet.
|
||||
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
|
||||
global_cond: (B, global_cond_dim)
|
||||
output: (B, T, input_dim)
|
||||
Returns:
|
||||
(B, T, input_dim) diffusion model prediction.
|
||||
"""
|
||||
# For 1D convolutions we'll need feature dimension first.
|
||||
x = einops.rearrange(x, "b t d -> b d t")
|
||||
|
||||
timesteps_embed = self.diffusion_step_encoder(timestep)
|
||||
|
||||
# If there is a global conditioning feature, concatenate it to the timestep embedding.
|
||||
if global_cond is not None:
|
||||
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
|
||||
else:
|
||||
global_feature = timesteps_embed
|
||||
|
||||
# Run encoder, keeping track of skip features to pass to the decoder.
|
||||
encoder_skip_features: list[Tensor] = []
|
||||
for resnet, resnet2, downsample in self.down_modules:
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
encoder_skip_features.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
for mid_module in self.mid_modules:
|
||||
x = mid_module(x, global_feature)
|
||||
|
||||
# Run decoder, using the skip features from the encoder.
|
||||
for resnet, resnet2, upsample in self.up_modules:
|
||||
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
|
||||
x = resnet(x, global_feature)
|
||||
x = resnet2(x, global_feature)
|
||||
x = upsample(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
x = einops.rearrange(x, "b d t -> b t d")
|
||||
return x
|
||||
|
||||
|
||||
class _ConditionalResidualBlock1D(nn.Module):
|
||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
cond_dim: int,
|
||||
kernel_size: int = 3,
|
||||
n_groups: int = 8,
|
||||
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
|
||||
# FiLM just modulates bias).
|
||||
use_film_scale_modulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, in_channels, T)
|
||||
cond: (B, cond_dim)
|
||||
Returns:
|
||||
(B, out_channels, T)
|
||||
"""
|
||||
out = self.conv1(x)
|
||||
|
||||
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
|
||||
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
|
||||
if self.use_film_scale_modulation:
|
||||
# Treat the embedding as a list of scales and biases.
|
||||
scale = cond_embed[:, : self.out_channels]
|
||||
bias = cond_embed[:, self.out_channels :]
|
||||
out = scale * out + bias
|
||||
else:
|
||||
# Treat the embedding as biases.
|
||||
out = out + cond_embed
|
||||
|
||||
out = self.conv2(out)
|
||||
out = out + self.residual_conv(x)
|
||||
return out
|
||||
|
||||
|
||||
class _EMA:
|
||||
"""
|
||||
Exponential Moving Average of models weights
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig, model: nn.Module):
|
||||
"""
|
||||
@crowsonkb's notes on EMA Warmup:
|
||||
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
||||
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
||||
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
||||
at 215.4k steps).
|
||||
Args:
|
||||
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
||||
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
||||
min_alpha (float): The minimum EMA decay rate. Default: 0.
|
||||
"""
|
||||
|
||||
self.averaged_model = model
|
||||
self.averaged_model.eval()
|
||||
self.averaged_model.requires_grad_(False)
|
||||
|
||||
self.update_after_step = cfg.ema_update_after_step
|
||||
self.inv_gamma = cfg.ema_inv_gamma
|
||||
self.power = cfg.ema_power
|
||||
self.min_alpha = cfg.ema_min_alpha
|
||||
self.max_alpha = cfg.ema_max_alpha
|
||||
|
||||
self.alpha = 0.0
|
||||
self.optimization_step = 0
|
||||
|
||||
def get_decay(self, optimization_step):
|
||||
"""
|
||||
Compute the decay factor for the exponential moving average.
|
||||
"""
|
||||
step = max(0, optimization_step - self.update_after_step - 1)
|
||||
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
||||
|
||||
if step <= 0:
|
||||
return 0.0
|
||||
|
||||
return max(self.min_alpha, min(value, self.max_alpha))
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, new_model):
|
||||
self.alpha = self.get_decay(self.optimization_step)
|
||||
|
||||
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True):
|
||||
# Iterate over immediate parameters only.
|
||||
for param, ema_param in zip(
|
||||
module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True
|
||||
):
|
||||
if isinstance(param, dict):
|
||||
raise RuntimeError("Dict parameter not supported")
|
||||
if isinstance(module, _BatchNorm) or not param.requires_grad:
|
||||
# Copy BatchNorm parameters, and non-trainable parameters directly.
|
||||
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
||||
else:
|
||||
ema_param.mul_(self.alpha)
|
||||
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha)
|
||||
|
||||
self.optimization_step += 1
|
||||
189
lerobot/common/policies/diffusion/multi_image_obs_encoder.py
Normal file
189
lerobot/common/policies/diffusion/multi_image_obs_encoder.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import copy
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
from diffusion_policy.common.pytorch_util import replace_submodules
|
||||
from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
|
||||
from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
shape_meta: dict,
|
||||
rgb_model: Union[nn.Module, Dict[str, nn.Module]],
|
||||
resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
|
||||
random_crop: bool = True,
|
||||
# replace BatchNorm with GroupNorm
|
||||
use_group_norm: bool = False,
|
||||
# use single rgb model for all rgb inputs
|
||||
share_rgb_model: bool = False,
|
||||
# renormalize rgb input with imagenet normalization
|
||||
# assuming input in [0,1]
|
||||
imagenet_norm: bool = False,
|
||||
):
|
||||
"""
|
||||
Assumes rgb input: B,C,H,W
|
||||
Assumes low_dim input: B,D
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
rgb_keys = []
|
||||
low_dim_keys = []
|
||||
key_model_map = nn.ModuleDict()
|
||||
key_transform_map = nn.ModuleDict()
|
||||
key_shape_map = {}
|
||||
|
||||
# handle sharing vision backbone
|
||||
if share_rgb_model:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
key_model_map["rgb"] = rgb_model
|
||||
|
||||
obs_shape_meta = shape_meta["obs"]
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr["shape"])
|
||||
type = attr.get("type", "low_dim")
|
||||
key_shape_map[key] = shape
|
||||
if type == "rgb":
|
||||
rgb_keys.append(key)
|
||||
# configure model for this key
|
||||
this_model = None
|
||||
if not share_rgb_model:
|
||||
if isinstance(rgb_model, dict):
|
||||
# have provided model for each key
|
||||
this_model = rgb_model[key]
|
||||
else:
|
||||
assert isinstance(rgb_model, nn.Module)
|
||||
# have a copy of the rgb model
|
||||
this_model = copy.deepcopy(rgb_model)
|
||||
|
||||
if this_model is not None:
|
||||
if use_group_norm:
|
||||
this_model = replace_submodules(
|
||||
root_module=this_model,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(
|
||||
num_groups=x.num_features // 16, num_channels=x.num_features
|
||||
),
|
||||
)
|
||||
key_model_map[key] = this_model
|
||||
|
||||
# configure resize
|
||||
input_shape = shape
|
||||
this_resizer = nn.Identity()
|
||||
if resize_shape is not None:
|
||||
if isinstance(resize_shape, dict):
|
||||
h, w = resize_shape[key]
|
||||
else:
|
||||
h, w = resize_shape
|
||||
this_resizer = torchvision.transforms.Resize(size=(h, w))
|
||||
input_shape = (shape[0], h, w)
|
||||
|
||||
# configure randomizer
|
||||
this_randomizer = nn.Identity()
|
||||
if crop_shape is not None:
|
||||
if isinstance(crop_shape, dict):
|
||||
h, w = crop_shape[key]
|
||||
else:
|
||||
h, w = crop_shape
|
||||
if random_crop:
|
||||
this_randomizer = CropRandomizer(
|
||||
input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
|
||||
)
|
||||
else:
|
||||
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
|
||||
# configure normalizer
|
||||
this_normalizer = nn.Identity()
|
||||
if imagenet_norm:
|
||||
# TODO(rcadene): move normalizer to dataset and env
|
||||
this_normalizer = torchvision.transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||
key_transform_map[key] = this_transform
|
||||
elif type == "low_dim":
|
||||
low_dim_keys.append(key)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported obs type: {type}")
|
||||
rgb_keys = sorted(rgb_keys)
|
||||
low_dim_keys = sorted(low_dim_keys)
|
||||
|
||||
self.shape_meta = shape_meta
|
||||
self.key_model_map = key_model_map
|
||||
self.key_transform_map = key_transform_map
|
||||
self.share_rgb_model = share_rgb_model
|
||||
self.rgb_keys = rgb_keys
|
||||
self.low_dim_keys = low_dim_keys
|
||||
self.key_shape_map = key_shape_map
|
||||
|
||||
def forward(self, obs_dict):
|
||||
batch_size = None
|
||||
features = []
|
||||
# process rgb input
|
||||
if self.share_rgb_model:
|
||||
# pass all rgb obs to rgb model
|
||||
imgs = []
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
imgs.append(img)
|
||||
# (N*B,C,H,W)
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
# (N*B,D)
|
||||
feature = self.key_model_map["rgb"](imgs)
|
||||
# (N,B,D)
|
||||
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
|
||||
# (B,N,D)
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
features.append(feature)
|
||||
|
||||
# process lowdim input
|
||||
for key in self.low_dim_keys:
|
||||
data = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = data.shape[0]
|
||||
else:
|
||||
assert batch_size == data.shape[0]
|
||||
assert data.shape[1:] == self.key_shape_map[key]
|
||||
features.append(data)
|
||||
|
||||
# concatenate all features
|
||||
result = torch.cat(features, dim=-1)
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def output_shape(self):
|
||||
example_obs_dict = {}
|
||||
obs_shape_meta = self.shape_meta["obs"]
|
||||
batch_size = 1
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr["shape"])
|
||||
this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
|
||||
example_obs_dict[key] = this_obs
|
||||
example_output = self.forward(example_obs_dict)
|
||||
output_shape = example_output.shape[1:]
|
||||
return output_shape
|
||||
199
lerobot/common/policies/diffusion/policy.py
Normal file
199
lerobot/common/policies/diffusion/policy.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import copy
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
|
||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from .multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
cfg_device,
|
||||
cfg_noise_scheduler,
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
cfg_optimizer,
|
||||
cfg_ema,
|
||||
shape_meta: dict,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
n_obs_steps,
|
||||
num_inference_steps=None,
|
||||
obs_as_global_cond=True,
|
||||
diffusion_step_embed_dim=256,
|
||||
down_dims=(256, 512, 1024),
|
||||
kernel_size=5,
|
||||
n_groups=8,
|
||||
cond_predict_scale=True,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg_obs_encoder,
|
||||
)
|
||||
|
||||
self.diffusion = DiffusionUnetImagePolicy(
|
||||
shape_meta=shape_meta,
|
||||
noise_scheduler=noise_scheduler,
|
||||
obs_encoder=obs_encoder,
|
||||
horizon=horizon,
|
||||
n_action_steps=n_action_steps,
|
||||
n_obs_steps=n_obs_steps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
obs_as_global_cond=obs_as_global_cond,
|
||||
diffusion_step_embed_dim=diffusion_step_embed_dim,
|
||||
down_dims=down_dims,
|
||||
kernel_size=kernel_size,
|
||||
n_groups=n_groups,
|
||||
cond_predict_scale=cond_predict_scale,
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.device = torch.device(cfg_device)
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
self.ema = hydra.utils.instantiate(
|
||||
cfg_ema,
|
||||
model=copy.deepcopy(self.diffusion),
|
||||
)
|
||||
|
||||
self.optimizer = hydra.utils.instantiate(
|
||||
cfg_optimizer,
|
||||
params=self.diffusion.parameters(),
|
||||
)
|
||||
|
||||
# TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
|
||||
self.global_step = 0
|
||||
|
||||
# configure lr scheduler
|
||||
self.lr_scheduler = get_scheduler(
|
||||
cfg.lr_scheduler,
|
||||
optimizer=self.optimizer,
|
||||
num_warmup_steps=cfg.lr_warmup_steps,
|
||||
num_training_steps=cfg.offline_steps,
|
||||
# pytorch assumes stepping LRScheduler every epoch
|
||||
# however huggingface diffusers steps it every batch
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
action = out["action"].squeeze(0)
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
||||
|
||||
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
|
||||
# |o|o| observations: 2
|
||||
# | |a|a|a|a|a|a|a|a| actions executed: 8
|
||||
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
|
||||
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
|
||||
|
||||
image = batch["observation", "image"]
|
||||
state = batch["observation", "state"]
|
||||
action = batch["action"]
|
||||
assert image.shape[1] == horizon
|
||||
assert state.shape[1] == horizon
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
|
||||
raise NotImplementedError()
|
||||
|
||||
# keep first 2 observations of the slice corresponding to t=[-1,0]
|
||||
image = image[:, : self.cfg.n_obs_steps]
|
||||
state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.ema is not None:
|
||||
self.ema.step(self.diffusion)
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
# TODO(rcadene): remove hardcoding
|
||||
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
|
||||
if step % 168 == 0:
|
||||
self.global_step += 1
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
@@ -1,61 +1,40 @@
|
||||
import inspect
|
||||
def make_policy(cfg):
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc import TDMPC
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
policy = TDMPC(cfg.policy, cfg.device)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
|
||||
assert set(hydra_cfg.policy).issuperset(
|
||||
expected_kwargs
|
||||
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: v
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
)
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPCPolicy(
|
||||
hydra_cfg.policy,
|
||||
n_obs_steps=hydra_cfg.n_obs_steps,
|
||||
n_action_steps=hydra_cfg.n_action_steps,
|
||||
device=hydra_cfg.device,
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif hydra_cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
)
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
if hydra_cfg.policy.pretrained_model_path:
|
||||
if cfg.policy.pretrained_model_path:
|
||||
# TODO(rcadene): hack for old pretrained models from fowm
|
||||
if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path:
|
||||
if "offline" in hydra_cfg.policy.pretrained_model_path:
|
||||
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
||||
if "offline" in cfg.pretrained_model_path:
|
||||
policy.step[0] = 25000
|
||||
elif "final" in hydra_cfg.policy.pretrained_model_path:
|
||||
elif "final" in cfg.pretrained_model_path:
|
||||
policy.step[0] = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(hydra_cfg.policy.pretrained_model_path)
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
if "image" in key:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"]
|
||||
buffer["std"].data = stats[key]["std"]
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"]
|
||||
buffer["max"].data = stats[key]["max"]
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shapes: dict[str, list[int]],
|
||||
modes: dict[str, str],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return batch
|
||||
@@ -1,45 +0,0 @@
|
||||
"""A protocol that all policies should follow.
|
||||
|
||||
This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
|
||||
subclass a base class.
|
||||
|
||||
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
|
||||
how to implement new policies.
|
||||
"""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Policy(Protocol):
|
||||
"""The required interface for implementing a policy."""
|
||||
|
||||
name: str
|
||||
|
||||
def reset(self):
|
||||
"""To be called whenever the environment is reset.
|
||||
|
||||
Does things like clearing caches.
|
||||
"""
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and maybe other information.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
with caching.
|
||||
"""
|
||||
|
||||
def update(self, batch):
|
||||
"""Does compute_loss then an optimization step.
|
||||
|
||||
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
|
||||
disappear.
|
||||
"""
|
||||
@@ -1,7 +1,6 @@
|
||||
# ruff: noqa: N806
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
|
||||
import einops
|
||||
@@ -9,9 +8,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
import lerobot.common.policies.tdmpc_helper as h
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -88,28 +85,24 @@ class TOLD(nn.Module):
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPCPolicy(nn.Module):
|
||||
class TDMPC(nn.Module):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, cfg, n_obs_steps, n_action_steps, device):
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__()
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.device = torch.device(device)
|
||||
self.std = h.linear_schedule(cfg.std_schedule, 0)
|
||||
self.model = TOLD(cfg)
|
||||
self.model.to(self.device)
|
||||
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.model.eval()
|
||||
self.model_target.eval()
|
||||
self.batch_size = cfg.batch_size
|
||||
|
||||
self.register_buffer("step", torch.zeros(1))
|
||||
|
||||
@@ -130,54 +123,21 @@ class TDMPCPolicy(nn.Module):
|
||||
self.model.load_state_dict(d["model"])
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch, step):
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
def forward(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
t0 = step == 0
|
||||
|
||||
self.eval()
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
if self.n_obs_steps == 1:
|
||||
# hack to remove the time dimension
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
actions = []
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
for i in range(batch_size):
|
||||
obs = {
|
||||
"rgb": batch["observation.image"][[i]],
|
||||
"state": batch["observation.state"][[i]],
|
||||
}
|
||||
# Note: unsqueeze needed because `act` still uses non-batch logic.
|
||||
action = self.act(obs, t0=t0, step=self.step)
|
||||
actions.append(action)
|
||||
action = torch.stack(actions)
|
||||
|
||||
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
|
||||
if i in range(self.n_action_steps):
|
||||
self._queues["action"].append(action)
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
obs = {
|
||||
# TODO(rcadene): remove contiguous hack...
|
||||
"rgb": observation["image"].contiguous(),
|
||||
"state": observation["state"].contiguous(),
|
||||
}
|
||||
action = self.act(obs, t0=t0, step=self.step.item())
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -326,58 +286,117 @@ class TDMPCPolicy(nn.Module):
|
||||
def _td_target(self, next_z, reward, mask):
|
||||
"""Compute the TD-target from a reward and the observation at the following time step."""
|
||||
next_v = self.model.V(next_z)
|
||||
td_target = reward + self.cfg.discount * mask * next_v.squeeze(2)
|
||||
td_target = reward + self.cfg.discount * mask * next_v
|
||||
return td_target
|
||||
|
||||
def forward(self, batch, step):
|
||||
# TODO(alexander-soare): Refactor TDMPC and make it comply with the policy interface documentation.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update(self, batch, step):
|
||||
def update(self, replay_buffer, step, demo_buffer=None):
|
||||
"""Main update function. Corresponds to one iteration of the model learning."""
|
||||
start_time = time.time()
|
||||
|
||||
batch_size = batch["index"].shape[0]
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
|
||||
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
|
||||
# batch size b = 256, time/horizon t = 5
|
||||
# b t ... -> t b ...
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
if demo_buffer is None:
|
||||
demo_batch_size = 0
|
||||
else:
|
||||
# Update oversampling ratio
|
||||
demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
|
||||
demo_num_slices = int(demo_pc_batch * self.batch_size)
|
||||
demo_batch_size = self.cfg.horizon * demo_num_slices
|
||||
batch_size -= demo_batch_size
|
||||
num_slices -= demo_num_slices
|
||||
replay_buffer._sampler.num_slices = num_slices
|
||||
demo_buffer._sampler.num_slices = demo_num_slices
|
||||
|
||||
action = batch["action"]
|
||||
reward = batch["next.reward"]
|
||||
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
|
||||
assert demo_batch_size % self.cfg.horizon == 0
|
||||
assert demo_batch_size % demo_num_slices == 0
|
||||
|
||||
obses = {
|
||||
"rgb": batch["observation.image"],
|
||||
"state": batch["observation.state"],
|
||||
}
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
shapes = {}
|
||||
for k in obses:
|
||||
shapes[k] = obses[k].shape
|
||||
obses[k] = einops.rearrange(obses[k], "t b ... -> (t b) ... ")
|
||||
# Sample from interaction dataset
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
|
||||
|
||||
obs = {
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
|
||||
"state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
|
||||
}
|
||||
action = batch["action"].to(self.device, non_blocking=True)
|
||||
next_obses = {
|
||||
"rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
|
||||
"state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
|
||||
}
|
||||
reward = batch["next", "reward"].to(self.device, non_blocking=True)
|
||||
|
||||
idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
|
||||
|
||||
# TODO(rcadene): rearrange directly in offline dataset
|
||||
if reward.ndim == 2:
|
||||
reward = einops.rearrange(reward, "h t -> h t 1")
|
||||
|
||||
assert reward.ndim == 3
|
||||
assert reward.shape == (horizon, num_slices, 1)
|
||||
# We dont use `batch["next", "done"]` since it only indicates the end of an
|
||||
# episode, but not the end of the trajectory of an episode.
|
||||
# Neither does `batch["next", "terminated"]`
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
return obs, action, next_obses, reward, mask, done, idxs, weights
|
||||
|
||||
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
|
||||
|
||||
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
|
||||
batch, self.cfg.horizon, num_slices
|
||||
)
|
||||
|
||||
# Sample from demonstration dataset
|
||||
if demo_batch_size > 0:
|
||||
demo_batch = demo_buffer.sample(demo_batch_size)
|
||||
(
|
||||
demo_obs,
|
||||
demo_action,
|
||||
demo_next_obses,
|
||||
demo_reward,
|
||||
demo_mask,
|
||||
demo_done,
|
||||
demo_idxs,
|
||||
demo_weights,
|
||||
) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
|
||||
|
||||
if isinstance(obs, dict):
|
||||
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
|
||||
next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
|
||||
else:
|
||||
obs = torch.cat([obs, demo_obs])
|
||||
next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
|
||||
action = torch.cat([action, demo_action], dim=1)
|
||||
reward = torch.cat([reward, demo_reward], dim=1)
|
||||
mask = torch.cat([mask, demo_mask], dim=1)
|
||||
done = torch.cat([done, demo_done], dim=1)
|
||||
idxs = torch.cat([idxs, demo_idxs])
|
||||
weights = torch.cat([weights, demo_weights])
|
||||
|
||||
# Apply augmentations
|
||||
aug_tf = h.aug(self.cfg)
|
||||
obses = aug_tf(obses)
|
||||
obs = aug_tf(obs)
|
||||
|
||||
for k in obses:
|
||||
t, b = shapes[k][:2]
|
||||
obses[k] = einops.rearrange(obses[k], "(t b) ... -> t b ... ", b=b, t=t)
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
|
||||
next_obses = aug_tf(next_obses)
|
||||
for k in next_obses:
|
||||
next_obses[k] = einops.rearrange(
|
||||
next_obses[k],
|
||||
"(h t) ... -> h t ...",
|
||||
h=self.cfg.horizon,
|
||||
t=self.cfg.batch_size,
|
||||
)
|
||||
|
||||
obs, next_obses = {}, {}
|
||||
for k in obses:
|
||||
obs[k] = obses[k][0]
|
||||
next_obses[k] = obses[k][1:].clone()
|
||||
|
||||
horizon = next_obses["rgb"].shape[0]
|
||||
horizon = self.cfg.horizon
|
||||
loss_mask = torch.ones_like(mask, device=self.device)
|
||||
for t in range(1, horizon):
|
||||
loss_mask[t] = loss_mask[t - 1] * (~done[t - 1])
|
||||
@@ -395,7 +414,7 @@ class TDMPCPolicy(nn.Module):
|
||||
td_targets = self._td_target(next_z, reward, mask)
|
||||
|
||||
# Latent rollout
|
||||
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
|
||||
zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
|
||||
reward_preds = torch.empty_like(reward, device=self.device)
|
||||
assert reward.shape[0] == horizon
|
||||
z = self.model.encode(obs)
|
||||
@@ -404,21 +423,22 @@ class TDMPCPolicy(nn.Module):
|
||||
for t in range(horizon):
|
||||
z, reward_pred = self.model.next(z, action[t])
|
||||
zs[t + 1] = z
|
||||
reward_preds[t] = reward_pred.squeeze(1)
|
||||
reward_preds[t] = reward_pred
|
||||
|
||||
with torch.no_grad():
|
||||
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
|
||||
|
||||
# Predictions
|
||||
qs = self.model.Q(zs[:-1], action, return_type="all")
|
||||
qs = qs.squeeze(3)
|
||||
value_info["Q"] = qs.mean().item()
|
||||
v = self.model.V(zs[:-1])
|
||||
value_info["V"] = v.mean().item()
|
||||
|
||||
# Losses
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
|
||||
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
|
||||
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
|
||||
dim=0
|
||||
)
|
||||
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
|
||||
q_value_loss, priority_loss = 0, 0
|
||||
for q in range(self.cfg.num_q):
|
||||
@@ -426,9 +446,7 @@ class TDMPCPolicy(nn.Module):
|
||||
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
|
||||
expectile = h.linear_schedule(self.cfg.expectile, step)
|
||||
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
|
||||
dim=0
|
||||
)
|
||||
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
|
||||
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss
|
||||
@@ -437,7 +455,7 @@ class TDMPCPolicy(nn.Module):
|
||||
+ self.cfg.value_coef * v_value_loss
|
||||
)
|
||||
|
||||
weighted_loss = (total_loss * weights).mean()
|
||||
weighted_loss = (total_loss.squeeze(1) * weights).mean()
|
||||
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
||||
has_nan = torch.isnan(weighted_loss).item()
|
||||
if has_nan:
|
||||
@@ -450,20 +468,19 @@ class TDMPCPolicy(nn.Module):
|
||||
)
|
||||
self.optim.step()
|
||||
|
||||
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
|
||||
# if self.cfg.per:
|
||||
# # Update priorities
|
||||
# priorities = priority_loss.clamp(max=1e4).detach()
|
||||
# has_nan = torch.isnan(priorities).any().item()
|
||||
# if has_nan:
|
||||
# print(f"priorities has nan: {priorities=}")
|
||||
# else:
|
||||
# replay_buffer.update_priority(
|
||||
# idxs[:num_slices],
|
||||
# priorities[:num_slices],
|
||||
# )
|
||||
# if demo_batch_size > 0:
|
||||
# demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
has_nan = torch.isnan(priorities).any().item()
|
||||
if has_nan:
|
||||
print(f"priorities has nan: {priorities=}")
|
||||
else:
|
||||
replay_buffer.update_priority(
|
||||
idxs[:num_slices],
|
||||
priorities[:num_slices],
|
||||
)
|
||||
if demo_batch_size > 0:
|
||||
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
@@ -486,7 +503,7 @@ class TDMPCPolicy(nn.Module):
|
||||
"data_s": data_s,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
# info["demo_batch_size"] = demo_batch_size
|
||||
info["demo_batch_size"] = demo_batch_size
|
||||
info["expectile"] = expectile
|
||||
info.update(value_info)
|
||||
info.update(pi_update_info)
|
||||
@@ -1,30 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
|
||||
for key in batch:
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
# initialize by copying the first observation several times until the queue is full
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
queues[key].append(batch[key])
|
||||
else:
|
||||
# add latest observation to the queue
|
||||
queues[key].append(batch[key])
|
||||
return queues
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
"""Get a module's device by checking one of its parameters.
|
||||
|
||||
Note: assumes that all parameters have the same device
|
||||
"""
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
"""Get a module's parameter dtype by checking one of its parameters.
|
||||
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
45
lerobot/common/utils.py
Normal file
45
lerobot/common/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
formatter = logging.Formatter()
|
||||
formatter.format = custom_format
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
|
||||
def format_big_number(num):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
for suffix in suffixes:
|
||||
if abs(num) < divisor:
|
||||
return f"{num:.0f}{suffix}"
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
@@ -1,44 +0,0 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
|
||||
def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
|
||||
"""Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
|
||||
Check if the package spec exists and grab its version to avoid importing a local directory.
|
||||
**Note:** this doesn't work for all packages.
|
||||
"""
|
||||
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
package_version = "N/A"
|
||||
if package_exists:
|
||||
try:
|
||||
# Primary method to get the package version
|
||||
package_version = importlib.metadata.version(pkg_name)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# Fallback method: Only for "torch" and versions containing "dev"
|
||||
if pkg_name == "torch":
|
||||
try:
|
||||
package = importlib.import_module(pkg_name)
|
||||
temp_version = getattr(package, "__version__", "N/A")
|
||||
# Check if the version contains "dev"
|
||||
if "dev" in temp_version:
|
||||
package_version = temp_version
|
||||
package_exists = True
|
||||
else:
|
||||
package_exists = False
|
||||
except ImportError:
|
||||
# If the package can't be imported, it's not available
|
||||
package_exists = False
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
||||
if return_version:
|
||||
return package_exists, package_version
|
||||
else:
|
||||
return package_exists
|
||||
|
||||
|
||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||
_gym_xarm_available = is_package_available("gym_xarm")
|
||||
_gym_aloha_available = is_package_available("gym_aloha")
|
||||
_gym_pusht_available = is_package_available("gym_pusht")
|
||||
@@ -1,12 +0,0 @@
|
||||
import warnings
|
||||
|
||||
import imageio
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
@@ -1,112 +0,0 @@
|
||||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
match cfg_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
case "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
case _:
|
||||
device = torch.device(cfg_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {cfg_device} device.")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def set_global_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
formatter = logging.Formatter()
|
||||
formatter.format = custom_format
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
|
||||
def format_big_number(num):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
for suffix in suffixes:
|
||||
if abs(num) < divisor:
|
||||
return f"{num:.0f}{suffix}"
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
|
||||
"""Initialize a Hydra config given only the path to the relevant config file.
|
||||
|
||||
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
|
||||
"""
|
||||
# TODO(alexander-soare): Resolve configs without Hydra initialization.
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
# Hydra needs a path relative to this file.
|
||||
hydra.initialize(
|
||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
return cfg
|
||||
|
||||
|
||||
def print_cuda_memory_usage():
|
||||
"""Use this function to locate and debug memory leak."""
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
# Also clear the cache if you want to fully release the memory
|
||||
torch.cuda.empty_cache()
|
||||
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user