forked from tangger/lerobot
Compare commits
12 Commits
qgallouede
...
qgallouede
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b11905b168 | ||
|
|
0f7552c3d4 | ||
|
|
ccffa9e406 | ||
|
|
791506dfb8 | ||
|
|
55dc9f7f51 | ||
|
|
81e490d46f | ||
|
|
a4b6c5e3b1 | ||
|
|
bf2eebb090 | ||
|
|
fe2b9af64f | ||
|
|
fdf6a0c4e3 | ||
|
|
45f351c618 | ||
|
|
b980c5dd9e |
142
.dockerignore
Normal file
142
.dockerignore
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# Misc
|
||||||
|
.git
|
||||||
|
tmp
|
||||||
|
wandb
|
||||||
|
data
|
||||||
|
outputs
|
||||||
|
.vscode
|
||||||
|
rl
|
||||||
|
media
|
||||||
|
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
logs
|
||||||
|
|
||||||
|
# HPC
|
||||||
|
nautilus/*.yaml
|
||||||
|
*.key
|
||||||
|
|
||||||
|
# Slurm
|
||||||
|
sbatch*.sh
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
!tests/data
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
25
.github/PULL_REQUEST_TEMPLATE.md
vendored
25
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,12 +1,29 @@
|
|||||||
# What does this PR do?
|
# What does this PR do?
|
||||||
|
|
||||||
Example: Fixes # (issue)
|
Examples:
|
||||||
|
- Fixes # (issue)
|
||||||
|
- Adds new dataset
|
||||||
|
- Optimizes something
|
||||||
|
|
||||||
|
## How was it tested?
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- Added `test_something` in `tests/test_stuff.py`.
|
||||||
|
- Added `new_feature` and checked that training converges with policy X on dataset/environment Y.
|
||||||
|
- Optimized `some_function`, it now runs X times faster than previously.
|
||||||
|
|
||||||
|
## How to checkout & try? (for the reviewer)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```bash
|
||||||
|
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
|
||||||
|
```
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py --some.option=true
|
||||||
|
```
|
||||||
|
|
||||||
## Before submitting
|
## Before submitting
|
||||||
- Read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
Please read the [contributor guideline](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md#submitting-a-pull-request-pr).
|
||||||
- Provide a minimal code example for the reviewer to checkout & try.
|
|
||||||
- Explain how you tested your changes.
|
|
||||||
|
|
||||||
|
|
||||||
## Who can review?
|
## Who can review?
|
||||||
|
|||||||
3917
.github/poetry/cpu/poetry.lock
generated
vendored
3917
.github/poetry/cpu/poetry.lock
generated
vendored
File diff suppressed because it is too large
Load Diff
107
.github/poetry/cpu/pyproject.toml
vendored
107
.github/poetry/cpu/pyproject.toml
vendored
@@ -1,107 +0,0 @@
|
|||||||
[tool.poetry]
|
|
||||||
name = "lerobot"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
|
||||||
authors = [
|
|
||||||
"Rémi Cadène <re.cadene@gmail.com>",
|
|
||||||
"Alexander Soare <alexander.soare159@gmail.com>",
|
|
||||||
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
|
|
||||||
"Simon Alibert <alibert.sim@gmail.com>",
|
|
||||||
"Thomas Wolf <thomaswolfcontact@gmail.com>",
|
|
||||||
]
|
|
||||||
repository = "https://github.com/huggingface/lerobot"
|
|
||||||
readme = "README.md"
|
|
||||||
license = "Apache-2.0"
|
|
||||||
classifiers=[
|
|
||||||
"Development Status :: 3 - Alpha",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"Intended Audience :: Education",
|
|
||||||
"Intended Audience :: Science/Research",
|
|
||||||
"Topic :: Software Development :: Build Tools",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"License :: OSI Approved :: Apache Software License",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
]
|
|
||||||
packages = [{include = "lerobot"}]
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
|
||||||
python = "^3.10"
|
|
||||||
termcolor = "^2.4.0"
|
|
||||||
omegaconf = "^2.3.0"
|
|
||||||
wandb = "^0.16.3"
|
|
||||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
|
||||||
gdown = "^5.1.0"
|
|
||||||
hydra-core = "^1.3.2"
|
|
||||||
einops = "^0.7.0"
|
|
||||||
pymunk = "^6.6.0"
|
|
||||||
zarr = "^2.17.0"
|
|
||||||
numba = "^0.59.0"
|
|
||||||
torch = {version = "^2.2.1", source = "torch-cpu"}
|
|
||||||
opencv-python = "^4.9.0.80"
|
|
||||||
diffusers = "^0.26.3"
|
|
||||||
torchvision = {version = "^0.17.1", source = "torch-cpu"}
|
|
||||||
h5py = "^3.10.0"
|
|
||||||
huggingface-hub = "^0.21.4"
|
|
||||||
robomimic = "0.2.0"
|
|
||||||
gymnasium = "^0.29.1"
|
|
||||||
cmake = "^3.29.0.1"
|
|
||||||
gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
|
||||||
gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
|
|
||||||
gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
|
|
||||||
pre-commit = {version = "^3.7.0", optional = true}
|
|
||||||
debugpy = {version = "^1.8.1", optional = true}
|
|
||||||
pytest = {version = "^8.1.0", optional = true}
|
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
|
||||||
datasets = "^2.19.0"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
|
||||||
pusht = ["gym-pusht"]
|
|
||||||
xarm = ["gym-xarm"]
|
|
||||||
aloha = ["gym-aloha"]
|
|
||||||
dev = ["pre-commit", "debugpy"]
|
|
||||||
test = ["pytest", "pytest-cov"]
|
|
||||||
|
|
||||||
|
|
||||||
[[tool.poetry.source]]
|
|
||||||
name = "torch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
priority = "supplemental"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 110
|
|
||||||
target-version = "py310"
|
|
||||||
exclude = [
|
|
||||||
".bzr",
|
|
||||||
".direnv",
|
|
||||||
".eggs",
|
|
||||||
".git",
|
|
||||||
".git-rewrite",
|
|
||||||
".hg",
|
|
||||||
".mypy_cache",
|
|
||||||
".nox",
|
|
||||||
".pants.d",
|
|
||||||
".pytype",
|
|
||||||
".ruff_cache",
|
|
||||||
".svn",
|
|
||||||
".tox",
|
|
||||||
".venv",
|
|
||||||
"__pypackages__",
|
|
||||||
"_build",
|
|
||||||
"buck-out",
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
"node_modules",
|
|
||||||
"venv",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["poetry-core>=1.5.0"]
|
|
||||||
build-backend = "poetry.core.masonry.api"
|
|
||||||
30
.github/scripts/dep_build.py
vendored
Normal file
30
.github/scripts/dep_build.py
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
PYPROJECT = "pyproject.toml"
|
||||||
|
DEPS = {
|
||||||
|
"gym-pusht": '{ git = "git@github.com:huggingface/gym-pusht.git", optional = true}',
|
||||||
|
"gym-xarm": '{ git = "git@github.com:huggingface/gym-xarm.git", optional = true}',
|
||||||
|
"gym-aloha": '{ git = "git@github.com:huggingface/gym-aloha.git", optional = true}',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def update_envs_as_path_dependencies():
|
||||||
|
with open(PYPROJECT) as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
|
||||||
|
new_lines = []
|
||||||
|
for line in lines:
|
||||||
|
if any(dep in line for dep in DEPS.values()):
|
||||||
|
for dep in DEPS:
|
||||||
|
if dep in line:
|
||||||
|
new_line = f'{dep} = {{ path = "envs/{dep}/", optional = true}}\n'
|
||||||
|
new_lines.append(new_line)
|
||||||
|
break
|
||||||
|
|
||||||
|
else:
|
||||||
|
new_lines.append(line)
|
||||||
|
|
||||||
|
with open(PYPROJECT, "w") as file:
|
||||||
|
file.writelines(new_lines)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
update_envs_as_path_dependencies()
|
||||||
203
.github/workflows/build-docker-images.yml
vendored
Normal file
203
.github/workflows/build-docker-images.yml
vendored
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# Inspired by
|
||||||
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml
|
||||||
|
name: Builds
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
workflow_call:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 1 * * *"
|
||||||
|
|
||||||
|
env:
|
||||||
|
PYTHON_VERSION: "3.10"
|
||||||
|
# CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
latest-cpu:
|
||||||
|
name: "Build CPU"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Cleanup disk
|
||||||
|
run: |
|
||||||
|
sudo df -h
|
||||||
|
# sudo ls -l /usr/local/lib/
|
||||||
|
# sudo ls -l /usr/share/
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo df -h
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# HACK(aliberts): to be removed for release
|
||||||
|
# -----------------------------------------
|
||||||
|
- name: Checkout gym-aloha
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-aloha
|
||||||
|
path: envs/gym-aloha
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-xarm
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-xarm
|
||||||
|
path: envs/gym-xarm
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-pusht
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-pusht
|
||||||
|
path: envs/gym-pusht
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Change envs dependencies as local path
|
||||||
|
run: python .github/scripts/dep_build.py
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
|
||||||
|
- name: Build and Push CPU
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./docker/lerobot-cpu/Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: huggingface/lerobot-cpu
|
||||||
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
# - name: Post to a Slack channel
|
||||||
|
# id: slack
|
||||||
|
# #uses: slackapi/slack-github-action@v1.25.0
|
||||||
|
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||||
|
# with:
|
||||||
|
# # Slack channel id, channel name, or user id to post message.
|
||||||
|
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||||
|
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||||
|
# # For posting a rich message using Block Kit
|
||||||
|
# payload: |
|
||||||
|
# {
|
||||||
|
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||||
|
# "blocks": [
|
||||||
|
# {
|
||||||
|
# "type": "section",
|
||||||
|
# "text": {
|
||||||
|
# "type": "mrkdwn",
|
||||||
|
# "text": "lerobot-cpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# env:
|
||||||
|
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||||
|
|
||||||
|
latest-cuda:
|
||||||
|
name: "Build GPU"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Cleanup disk
|
||||||
|
run: |
|
||||||
|
sudo df -h
|
||||||
|
# sudo ls -l /usr/local/lib/
|
||||||
|
# sudo ls -l /usr/share/
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo df -h
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# HACK(aliberts): to be removed for release
|
||||||
|
# -----------------------------------------
|
||||||
|
- name: Checkout gym-aloha
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-aloha
|
||||||
|
path: envs/gym-aloha
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-xarm
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-xarm
|
||||||
|
path: envs/gym-xarm
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-pusht
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-pusht
|
||||||
|
path: envs/gym-pusht
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Change envs dependencies as local path
|
||||||
|
run: python .github/scripts/dep_build.py
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
|
||||||
|
- name: Build and Push GPU
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./docker/lerobot-gpu/Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: huggingface/lerobot-gpu
|
||||||
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
# - name: Post to a Slack channel
|
||||||
|
# id: slack
|
||||||
|
# #uses: slackapi/slack-github-action@v1.25.0
|
||||||
|
# uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||||
|
# with:
|
||||||
|
# # Slack channel id, channel name, or user id to post message.
|
||||||
|
# # See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||||
|
# channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||||
|
# # For posting a rich message using Block Kit
|
||||||
|
# payload: |
|
||||||
|
# {
|
||||||
|
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||||
|
# "blocks": [
|
||||||
|
# {
|
||||||
|
# "type": "section",
|
||||||
|
# "text": {
|
||||||
|
# "type": "mrkdwn",
|
||||||
|
# "text": "lerobot-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# env:
|
||||||
|
# SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||||
79
.github/workflows/nightly-tests.yml
vendored
Normal file
79
.github/workflows/nightly-tests.yml
vendored
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# Inspired by
|
||||||
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml
|
||||||
|
name: Nightly
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 2 * * *"
|
||||||
|
|
||||||
|
env:
|
||||||
|
DATA_DIR: tests/data
|
||||||
|
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
run_all_tests_cpu:
|
||||||
|
name: "Test CPU"
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: huggingface/lerobot-cpu:latest
|
||||||
|
options: --shm-size "16gb"
|
||||||
|
credentials:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: /lerobot
|
||||||
|
steps:
|
||||||
|
- name: Tests
|
||||||
|
env:
|
||||||
|
DATA_DIR: tests/data
|
||||||
|
run: pytest -v --cov=./lerobot --disable-warnings tests
|
||||||
|
|
||||||
|
- name: Tests end-to-end
|
||||||
|
env:
|
||||||
|
DATA_DIR: tests/data
|
||||||
|
run: make test-end-to-end
|
||||||
|
|
||||||
|
|
||||||
|
run_all_tests_single_gpu:
|
||||||
|
name: "Test GPU"
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||||
|
env:
|
||||||
|
CUDA_VISIBLE_DEVICES: "0"
|
||||||
|
TEST_TYPE: "single_gpu"
|
||||||
|
container:
|
||||||
|
image: huggingface/lerobot-gpu:latest
|
||||||
|
options: --gpus all --shm-size "16gb"
|
||||||
|
credentials:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: /lerobot
|
||||||
|
steps:
|
||||||
|
- name: Nvidia-smi
|
||||||
|
run: nvidia-smi
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: pytest -v --cov=./lerobot --cov-report=xml --disable-warnings tests
|
||||||
|
# TODO(aliberts): Link with HF Codecov account
|
||||||
|
# - name: Upload coverage reports to Codecov with GitHub Action
|
||||||
|
# uses: codecov/codecov-action@v4
|
||||||
|
# with:
|
||||||
|
# files: ./coverage.xml
|
||||||
|
# verbose: true
|
||||||
|
- name: Tests end-to-end
|
||||||
|
run: make test-end-to-end
|
||||||
|
|
||||||
|
# - name: Generate Report
|
||||||
|
# if: always()
|
||||||
|
# run: |
|
||||||
|
# pip install slack_sdk tabulate
|
||||||
|
# python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||||
38
.github/workflows/style.yml
vendored
Normal file
38
.github/workflows/style.yml
vendored
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
name: Style
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
workflow_call:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
env:
|
||||||
|
PYTHON_VERSION: "3.10"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
ruff_check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repository
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Get Ruff Version from pre-commit-config.yaml
|
||||||
|
id: get-ruff-version
|
||||||
|
run: |
|
||||||
|
RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml)
|
||||||
|
echo "RUFF_VERSION=${RUFF_VERSION}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Install Ruff
|
||||||
|
run: python -m pip install "ruff==${{ env.RUFF_VERSION }}"
|
||||||
|
|
||||||
|
- name: Run Ruff
|
||||||
|
run: ruff check .
|
||||||
109
.github/workflows/test-docker-build.yml
vendored
Normal file
109
.github/workflows/test-docker-build.yml
vendored
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Inspired by
|
||||||
|
# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml
|
||||||
|
name: Test Docker builds (PR)
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
paths:
|
||||||
|
# Run only when DockerFile files are modified
|
||||||
|
- "docker/**"
|
||||||
|
|
||||||
|
env:
|
||||||
|
PYTHON_VERSION: "3.10"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
get_changed_files:
|
||||||
|
name: "Get all modified Dockerfiles"
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||||
|
steps:
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get changed files
|
||||||
|
id: changed-files
|
||||||
|
uses: tj-actions/changed-files@v44
|
||||||
|
with:
|
||||||
|
files: docker/**
|
||||||
|
json: "true"
|
||||||
|
|
||||||
|
- name: Run step if only the files listed above change
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
id: set-matrix
|
||||||
|
env:
|
||||||
|
ALL_CHANGED_FILES: ${{ steps.changed-files.outputs.all_changed_files }}
|
||||||
|
run: |
|
||||||
|
echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
|
||||||
|
build_modified_dockerfiles:
|
||||||
|
name: "Build all modified Docker images"
|
||||||
|
needs: get_changed_files
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||||
|
steps:
|
||||||
|
- name: Cleanup disk
|
||||||
|
run: |
|
||||||
|
sudo df -h
|
||||||
|
# sudo ls -l /usr/local/lib/
|
||||||
|
# sudo ls -l /usr/share/
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo du -sh /usr/local/lib/
|
||||||
|
sudo du -sh /usr/share/
|
||||||
|
sudo df -h
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
# HACK(aliberts): to be removed for release
|
||||||
|
# -----------------------------------------
|
||||||
|
- name: Checkout gym-aloha
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-aloha
|
||||||
|
path: envs/gym-aloha
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-xarm
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-xarm
|
||||||
|
path: envs/gym-xarm
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Checkout gym-pusht
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
repository: huggingface/gym-pusht
|
||||||
|
path: envs/gym-pusht
|
||||||
|
ssh-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Change envs dependencies as local path
|
||||||
|
run: python .github/scripts/dep_build.py
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
file: ${{ matrix.docker-file }}
|
||||||
|
context: .
|
||||||
|
push: False
|
||||||
|
build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }}
|
||||||
231
.github/workflows/test.yml
vendored
231
.github/workflows/test.yml
vendored
@@ -4,210 +4,71 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
types: [opened, synchronize, reopened, labeled]
|
paths:
|
||||||
|
- "lerobot/**"
|
||||||
|
- "tests/**"
|
||||||
|
- "examples/**"
|
||||||
|
- ".github/**"
|
||||||
|
- "poetry.lock"
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
|
paths:
|
||||||
|
- "lerobot/**"
|
||||||
|
- "tests/**"
|
||||||
|
- "examples/**"
|
||||||
|
- ".github/**"
|
||||||
|
- "poetry.lock"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
tests:
|
tests:
|
||||||
if: |
|
runs-on: ${{ matrix.os }}
|
||||||
${{ github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'CI') }} ||
|
strategy:
|
||||||
${{ github.event_name == 'push' }}
|
matrix:
|
||||||
runs-on: ubuntu-latest
|
os: [ubuntu-latest, macos-latest, macos-latest-large]
|
||||||
env:
|
env:
|
||||||
POETRY_VERSION: 1.8.2
|
|
||||||
DATA_DIR: tests/data
|
DATA_DIR: tests/data
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
steps:
|
steps:
|
||||||
#----------------------------------------------
|
|
||||||
# check-out repo and set-up python
|
|
||||||
#----------------------------------------------
|
|
||||||
- name: Check out repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
lfs: true
|
|
||||||
|
|
||||||
- name: Set up python
|
|
||||||
id: setup-python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Add SSH key for installing envs
|
- name: Add SSH key for installing envs
|
||||||
uses: webfactory/ssh-agent@v0.9.0
|
uses: webfactory/ssh-agent@v0.9.0
|
||||||
with:
|
with:
|
||||||
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
|
|
||||||
#----------------------------------------------
|
- uses: actions/checkout@v4
|
||||||
# install & configure poetry
|
|
||||||
#----------------------------------------------
|
- name: Install EGL
|
||||||
- name: Load cached Poetry installation
|
run: |
|
||||||
id: restore-poetry-cache
|
if [[ "${{ matrix.os }}" == 'ubuntu-latest' ]]; then
|
||||||
uses: actions/cache/restore@v3
|
sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||||
|
elif [[ "${{ matrix.os }}" == 'macos-latest' || "${{ matrix.os }}" == 'macos-latest-large' ]]; then
|
||||||
|
brew install mesa
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Install poetry
|
||||||
|
run: |
|
||||||
|
pipx install poetry && poetry config virtualenvs.in-project true
|
||||||
|
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
|
||||||
|
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
path: ~/.local
|
python-version: "3.10"
|
||||||
key: poetry-${{ env.POETRY_VERSION }}
|
cache: "poetry"
|
||||||
|
|
||||||
- name: Install Poetry
|
- name: Install poetry dependencies
|
||||||
if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
|
|
||||||
uses: snok/install-poetry@v1
|
|
||||||
with:
|
|
||||||
version: ${{ env.POETRY_VERSION }}
|
|
||||||
virtualenvs-create: true
|
|
||||||
installer-parallel: true
|
|
||||||
|
|
||||||
- name: Save cached Poetry installation
|
|
||||||
if: |
|
|
||||||
steps.restore-poetry-cache.outputs.cache-hit != 'true' &&
|
|
||||||
github.ref_name == 'main'
|
|
||||||
id: save-poetry-cache
|
|
||||||
uses: actions/cache/save@v3
|
|
||||||
with:
|
|
||||||
path: ~/.local
|
|
||||||
key: poetry-${{ env.POETRY_VERSION }}
|
|
||||||
|
|
||||||
- name: Configure Poetry
|
|
||||||
run: poetry config virtualenvs.in-project true
|
|
||||||
|
|
||||||
#----------------------------------------------
|
|
||||||
# install dependencies
|
|
||||||
#----------------------------------------------
|
|
||||||
# TODO(aliberts): move to gpu runners
|
|
||||||
- name: Select cpu dependencies # HACK
|
|
||||||
run: cp -t . .github/poetry/cpu/pyproject.toml .github/poetry/cpu/poetry.lock
|
|
||||||
|
|
||||||
- name: Load cached venv
|
|
||||||
id: restore-dependencies-cache
|
|
||||||
uses: actions/cache/restore@v3
|
|
||||||
with:
|
|
||||||
path: .venv
|
|
||||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
|
|
||||||
env:
|
|
||||||
TMPDIR: ~/tmp
|
|
||||||
TEMP: ~/tmp
|
|
||||||
TMP: ~/tmp
|
|
||||||
run: |
|
run: |
|
||||||
mkdir ~/tmp
|
poetry install --all-extras
|
||||||
poetry install --no-interaction --no-root --all-extras
|
|
||||||
|
|
||||||
- name: Save cached venv
|
- name: Test with pytest
|
||||||
if: |
|
|
||||||
steps.restore-dependencies-cache.outputs.cache-hit != 'true' &&
|
|
||||||
github.ref_name == 'main'
|
|
||||||
id: save-dependencies-cache
|
|
||||||
uses: actions/cache/save@v3
|
|
||||||
with:
|
|
||||||
path: .venv
|
|
||||||
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
|
|
||||||
|
|
||||||
- name: Install libegl1-mesa-dev (to use MUJOCO_GL=egl)
|
|
||||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
|
||||||
|
|
||||||
#----------------------------------------------
|
|
||||||
# install project
|
|
||||||
#----------------------------------------------
|
|
||||||
- name: Install project
|
|
||||||
run: poetry install --no-interaction --all-extras
|
|
||||||
|
|
||||||
#----------------------------------------------
|
|
||||||
# run tests & coverage
|
|
||||||
#----------------------------------------------
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
pytest tests -v --cov=./lerobot --durations=0 \
|
||||||
pytest -v --cov=./lerobot --cov-report=xml tests
|
-W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \
|
||||||
|
-W ignore::UserWarning:torch.utils.data.dataloader:558 \
|
||||||
|
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||||
|
&& rm -rf tests/outputs outputs
|
||||||
|
|
||||||
# TODO(aliberts): Link with HF Codecov account
|
- name: Test end-to-end
|
||||||
# - name: Upload coverage reports to Codecov with GitHub Action
|
|
||||||
# uses: codecov/codecov-action@v4
|
|
||||||
# with:
|
|
||||||
# files: ./coverage.xml
|
|
||||||
# verbose: true
|
|
||||||
|
|
||||||
#----------------------------------------------
|
|
||||||
# run end-to-end tests
|
|
||||||
#----------------------------------------------
|
|
||||||
- name: Test train ACT on ALOHA end-to-end
|
|
||||||
run: |
|
run: |
|
||||||
source .venv/bin/activate
|
make test-end-to-end \
|
||||||
python lerobot/scripts/train.py \
|
&& rm -rf outputs
|
||||||
policy=act \
|
|
||||||
env=aloha \
|
|
||||||
wandb.enable=False \
|
|
||||||
offline_steps=2 \
|
|
||||||
online_steps=0 \
|
|
||||||
eval_episodes=1 \
|
|
||||||
device=cpu \
|
|
||||||
save_model=true \
|
|
||||||
save_freq=2 \
|
|
||||||
policy.n_action_steps=20 \
|
|
||||||
policy.chunk_size=20 \
|
|
||||||
policy.batch_size=2 \
|
|
||||||
hydra.run.dir=tests/outputs/act/
|
|
||||||
|
|
||||||
- name: Test eval ACT on ALOHA end-to-end
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate
|
|
||||||
python lerobot/scripts/eval.py \
|
|
||||||
--config tests/outputs/act/.hydra/config.yaml \
|
|
||||||
eval_episodes=1 \
|
|
||||||
env.episode_length=8 \
|
|
||||||
device=cpu \
|
|
||||||
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
|
||||||
|
|
||||||
- name: Test train Diffusion on PushT end-to-end
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate
|
|
||||||
python lerobot/scripts/train.py \
|
|
||||||
policy=diffusion \
|
|
||||||
env=pusht \
|
|
||||||
wandb.enable=False \
|
|
||||||
offline_steps=2 \
|
|
||||||
online_steps=0 \
|
|
||||||
eval_episodes=1 \
|
|
||||||
device=cpu \
|
|
||||||
save_model=true \
|
|
||||||
save_freq=2 \
|
|
||||||
policy.batch_size=2 \
|
|
||||||
hydra.run.dir=tests/outputs/diffusion/
|
|
||||||
|
|
||||||
- name: Test eval Diffusion on PushT end-to-end
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate
|
|
||||||
python lerobot/scripts/eval.py \
|
|
||||||
--config tests/outputs/diffusion/.hydra/config.yaml \
|
|
||||||
eval_episodes=1 \
|
|
||||||
env.episode_length=8 \
|
|
||||||
device=cpu \
|
|
||||||
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
|
||||||
|
|
||||||
- name: Test train TDMPC on Simxarm end-to-end
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate
|
|
||||||
python lerobot/scripts/train.py \
|
|
||||||
policy=tdmpc \
|
|
||||||
env=xarm \
|
|
||||||
wandb.enable=False \
|
|
||||||
offline_steps=1 \
|
|
||||||
online_steps=2 \
|
|
||||||
eval_episodes=1 \
|
|
||||||
env.episode_length=2 \
|
|
||||||
device=cpu \
|
|
||||||
save_model=true \
|
|
||||||
save_freq=2 \
|
|
||||||
policy.batch_size=2 \
|
|
||||||
hydra.run.dir=tests/outputs/tdmpc/
|
|
||||||
|
|
||||||
- name: Test eval TDMPC on Simxarm end-to-end
|
|
||||||
run: |
|
|
||||||
source .venv/bin/activate
|
|
||||||
python lerobot/scripts/eval.py \
|
|
||||||
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
|
||||||
eval_episodes=1 \
|
|
||||||
env.episode_length=8 \
|
|
||||||
device=cpu \
|
|
||||||
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
exclude: ^(data/|tests/data)
|
exclude: ^(tests/data)
|
||||||
default_language_version:
|
default_language_version:
|
||||||
python: python3.10
|
python: python3.10
|
||||||
repos:
|
repos:
|
||||||
@@ -18,7 +18,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.3.7
|
rev: v0.4.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix]
|
args: [--fix]
|
||||||
|
|||||||
@@ -129,41 +129,38 @@ Follow these steps to start contributing:
|
|||||||
|
|
||||||
🚨 **Do not** work on the `main` branch.
|
🚨 **Do not** work on the `main` branch.
|
||||||
|
|
||||||
4. Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
|
4. for development, we use `poetry` instead of just `pip` to easily track our dependencies.
|
||||||
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
|
||||||
Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
|
||||||
Install the project with dev dependencies and all environments:
|
|
||||||
```bash
|
|
||||||
poetry install --sync --with dev --all-extras
|
|
||||||
```
|
|
||||||
This command should be run when pulling code with and updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
|
|
||||||
|
|
||||||
To selectively install environments (for example aloha and pusht) use:
|
Set up a development environment with conda or miniconda:
|
||||||
```bash
|
```bash
|
||||||
poetry install --sync --with dev --extras "aloha pusht"
|
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library:
|
||||||
|
```bash
|
||||||
|
poetry install --sync --extras "dev test"
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also install the project with all its dependencies (including environments):
|
||||||
|
```bash
|
||||||
|
poetry install --sync --all-extras
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing.
|
||||||
|
|
||||||
|
Whichever command you chose to install the project (e.g. `poetry install --sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies.
|
||||||
|
|
||||||
The equivalent of `pip install some-package`, would just be:
|
The equivalent of `pip install some-package`, would just be:
|
||||||
```bash
|
```bash
|
||||||
poetry add some-package
|
poetry add some-package
|
||||||
```
|
```
|
||||||
|
|
||||||
When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
|
||||||
```bash
|
```bash
|
||||||
poetry lock --no-update
|
poetry lock --no-update
|
||||||
```
|
```
|
||||||
|
|
||||||
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
|
|
||||||
```bash
|
|
||||||
# Add the new package to your main poetry env
|
|
||||||
poetry add some-package
|
|
||||||
# Add the same package to the CPU-only env dedicated to CI
|
|
||||||
conda create -y -n lerobot-ci python=3.10
|
|
||||||
conda activate lerobot-ci
|
|
||||||
cd .github/poetry/cpu
|
|
||||||
poetry add some-package
|
|
||||||
```
|
|
||||||
|
|
||||||
5. Develop the features on your branch.
|
5. Develop the features on your branch.
|
||||||
|
|
||||||
As you work on the features, you should make sure that the test suite
|
As you work on the features, you should make sure that the test suite
|
||||||
|
|||||||
95
Makefile
Normal file
95
Makefile
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
.PHONY: tests
|
||||||
|
|
||||||
|
PYTHON_PATH := $(shell which python)
|
||||||
|
|
||||||
|
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||||
|
POETRY_CHECK := $(shell command -v poetry)
|
||||||
|
ifneq ($(POETRY_CHECK),)
|
||||||
|
PYTHON_PATH := $(shell poetry run which python)
|
||||||
|
endif
|
||||||
|
|
||||||
|
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||||
|
|
||||||
|
|
||||||
|
build-cpu:
|
||||||
|
docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
|
||||||
|
|
||||||
|
build-gpu:
|
||||||
|
docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
|
||||||
|
|
||||||
|
test-end-to-end:
|
||||||
|
${MAKE} test-act-ete-train
|
||||||
|
${MAKE} test-act-ete-eval
|
||||||
|
${MAKE} test-diffusion-ete-train
|
||||||
|
${MAKE} test-diffusion-ete-eval
|
||||||
|
${MAKE} test-tdmpc-ete-train
|
||||||
|
${MAKE} test-tdmpc-ete-eval
|
||||||
|
|
||||||
|
test-act-ete-train:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=act \
|
||||||
|
env=aloha \
|
||||||
|
wandb.enable=False \
|
||||||
|
offline_steps=2 \
|
||||||
|
online_steps=0 \
|
||||||
|
eval_episodes=1 \
|
||||||
|
device=cpu \
|
||||||
|
save_model=true \
|
||||||
|
save_freq=2 \
|
||||||
|
policy.n_action_steps=20 \
|
||||||
|
policy.chunk_size=20 \
|
||||||
|
policy.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act/
|
||||||
|
|
||||||
|
test-act-ete-eval:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
--config tests/outputs/act/.hydra/config.yaml \
|
||||||
|
eval_episodes=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
policy.pretrained_model_path=tests/outputs/act/models/2.pt
|
||||||
|
|
||||||
|
test-diffusion-ete-train:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=diffusion \
|
||||||
|
env=pusht \
|
||||||
|
wandb.enable=False \
|
||||||
|
offline_steps=2 \
|
||||||
|
online_steps=0 \
|
||||||
|
eval_episodes=1 \
|
||||||
|
device=cpu \
|
||||||
|
save_model=true \
|
||||||
|
save_freq=2 \
|
||||||
|
policy.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/diffusion/
|
||||||
|
|
||||||
|
test-diffusion-ete-eval:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
--config tests/outputs/diffusion/.hydra/config.yaml \
|
||||||
|
eval_episodes=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
policy.pretrained_model_path=tests/outputs/diffusion/models/2.pt
|
||||||
|
|
||||||
|
test-tdmpc-ete-train:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=tdmpc \
|
||||||
|
env=xarm \
|
||||||
|
wandb.enable=False \
|
||||||
|
offline_steps=1 \
|
||||||
|
online_steps=2 \
|
||||||
|
eval_episodes=1 \
|
||||||
|
env.episode_length=2 \
|
||||||
|
device=cpu \
|
||||||
|
save_model=true \
|
||||||
|
save_freq=2 \
|
||||||
|
policy.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/tdmpc/
|
||||||
|
|
||||||
|
test-tdmpc-ete-eval:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
--config tests/outputs/tdmpc/.hydra/config.yaml \
|
||||||
|
eval_episodes=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
policy.pretrained_model_path=tests/outputs/tdmpc/models/2.pt
|
||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
[](https://github.com/huggingface/lerobot/actions/workflows/test.yml?query=branch%3Amain)
|
[](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml?query=branch%3Amain)
|
||||||
[](https://codecov.io/gh/huggingface/lerobot)
|
[](https://codecov.io/gh/huggingface/lerobot)
|
||||||
[](https://www.python.org/downloads/)
|
[](https://www.python.org/downloads/)
|
||||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||||
@@ -73,7 +73,7 @@ conda create -y -n lerobot python=3.10 && conda activate lerobot
|
|||||||
|
|
||||||
Install 🤗 LeRobot:
|
Install 🤗 LeRobot:
|
||||||
```bash
|
```bash
|
||||||
python -m pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||||
@@ -83,7 +83,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
|||||||
|
|
||||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||||
```bash
|
```bash
|
||||||
python -m pip install ".[aloha, pusht]"
|
pip install ".[aloha, pusht]"
|
||||||
```
|
```
|
||||||
|
|
||||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
|
||||||
|
|||||||
31
docker/lerobot-cpu/Dockerfile
Normal file
31
docker/lerobot-cpu/Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Configure image
|
||||||
|
ARG PYTHON_VERSION=3.10
|
||||||
|
|
||||||
|
FROM python:${PYTHON_VERSION}-slim
|
||||||
|
ARG PYTHON_VERSION
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install apt dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake \
|
||||||
|
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||||
|
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||||
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||||
|
|
||||||
|
# Install LeRobot
|
||||||
|
COPY . /lerobot
|
||||||
|
WORKDIR /lerobot
|
||||||
|
RUN pip install --upgrade --no-cache-dir pip
|
||||||
|
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
|
# Set EGL as the rendering backend for MuJoCo
|
||||||
|
ENV MUJOCO_GL="egl"
|
||||||
|
|
||||||
|
# Execute in bash shell rather than python
|
||||||
|
CMD ["/bin/bash"]
|
||||||
27
docker/lerobot-gpu/Dockerfile
Normal file
27
docker/lerobot-gpu/Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||||
|
|
||||||
|
# Configure image
|
||||||
|
ARG PYTHON_VERSION=3.10
|
||||||
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# Install apt dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential cmake \
|
||||||
|
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||||
|
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||||
|
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
|
||||||
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||||
|
|
||||||
|
# Install LeRobot
|
||||||
|
COPY . /lerobot
|
||||||
|
WORKDIR /lerobot
|
||||||
|
RUN pip install --upgrade --no-cache-dir pip
|
||||||
|
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||||
|
|
||||||
|
# Set EGL as the rendering backend for MuJoCo
|
||||||
|
ENV MUJOCO_GL="egl"
|
||||||
@@ -1,550 +0,0 @@
|
|||||||
"""
|
|
||||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
|
||||||
useless dependencies when using datasets.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import pickle
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
from PIL import Image as PILImage
|
|
||||||
from safetensors.torch import save_file
|
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload(root, revision, dataset_id):
|
|
||||||
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
|
|
||||||
if "pusht" in dataset_id:
|
|
||||||
download_and_upload_pusht(root, revision, dataset_id)
|
|
||||||
elif "xarm" in dataset_id:
|
|
||||||
download_and_upload_xarm(root, revision, dataset_id)
|
|
||||||
elif "aloha" in dataset_id:
|
|
||||||
download_and_upload_aloha(root, revision, dataset_id)
|
|
||||||
else:
|
|
||||||
raise ValueError(dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
print(f"downloading from {url}")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
if response.status_code == 200:
|
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
|
||||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
|
||||||
|
|
||||||
zip_file = io.BytesIO()
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
zip_file.write(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
|
|
||||||
progress_bar.close()
|
|
||||||
|
|
||||||
zip_file.seek(0)
|
|
||||||
|
|
||||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
|
||||||
zip_ref.extractall(destination_folder)
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def concatenate_episodes(ep_dicts):
|
|
||||||
data_dict = {}
|
|
||||||
|
|
||||||
keys = ep_dicts[0].keys()
|
|
||||||
for key in keys:
|
|
||||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
|
||||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
|
||||||
else:
|
|
||||||
if key not in data_dict:
|
|
||||||
data_dict[key] = []
|
|
||||||
for ep_dict in ep_dicts:
|
|
||||||
for x in ep_dict[key]:
|
|
||||||
data_dict[key].append(x)
|
|
||||||
|
|
||||||
total_frames = data_dict["frame_index"].shape[0]
|
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
|
||||||
# push to main to indicate latest version
|
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
|
||||||
|
|
||||||
# push to version branch
|
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
|
||||||
|
|
||||||
# create and store meta_data
|
|
||||||
meta_data_dir = root / dataset_id / "meta_data"
|
|
||||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
api = HfApi()
|
|
||||||
|
|
||||||
# info
|
|
||||||
info_path = meta_data_dir / "info.json"
|
|
||||||
with open(str(info_path), "w") as f:
|
|
||||||
json.dump(info, f, indent=4)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=info_path,
|
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=info_path,
|
|
||||||
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# stats
|
|
||||||
stats_path = meta_data_dir / "stats.safetensors"
|
|
||||||
save_file(flatten_dict(stats), stats_path)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=stats_path,
|
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=stats_path,
|
|
||||||
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# episode_data_index
|
|
||||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
|
||||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
|
||||||
save_file(episode_data_index, ep_data_idx_path)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
)
|
|
||||||
api.upload_file(
|
|
||||||
path_or_fileobj=ep_data_idx_path,
|
|
||||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
|
|
||||||
repo_id=f"lerobot/{dataset_id}",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision=revision,
|
|
||||||
)
|
|
||||||
|
|
||||||
# copy in tests folder, the first episode and the meta_data directory
|
|
||||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
|
||||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
|
||||||
f"tests/data/lerobot/{dataset_id}/train"
|
|
||||||
)
|
|
||||||
if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists():
|
|
||||||
shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data")
|
|
||||||
shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data")
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
|
||||||
try:
|
|
||||||
import pymunk
|
|
||||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
|
||||||
|
|
||||||
from lerobot.common.datasets._diffusion_policy_replay_buffer import (
|
|
||||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
|
||||||
)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# as define in env
|
|
||||||
success_threshold = 0.95 # 95% coverage,
|
|
||||||
|
|
||||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
|
||||||
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
|
|
||||||
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / f"{dataset_id}_raw"
|
|
||||||
zarr_path = (raw_dir / pusht_zarr).resolve()
|
|
||||||
if not zarr_path.is_dir():
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
download_and_extract_zip(pusht_url, raw_dir)
|
|
||||||
|
|
||||||
# load
|
|
||||||
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) # , keys=['img', 'state', 'action'])
|
|
||||||
|
|
||||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
|
||||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
|
||||||
assert len(
|
|
||||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
|
||||||
), "Some data type dont have the same number of total frames."
|
|
||||||
|
|
||||||
# TODO: verify that goal pose is expected to be fixed
|
|
||||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
|
||||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
|
||||||
|
|
||||||
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
|
|
||||||
states = torch.from_numpy(dataset_dict["state"])
|
|
||||||
actions = torch.from_numpy(dataset_dict["action"])
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
|
||||||
id_to = dataset_dict.meta["episode_ends"][episode_id]
|
|
||||||
|
|
||||||
num_frames = id_to - id_from
|
|
||||||
|
|
||||||
assert (episode_ids[id_from:id_to] == episode_id).all()
|
|
||||||
|
|
||||||
image = imgs[id_from:id_to]
|
|
||||||
assert image.min() >= 0.0
|
|
||||||
assert image.max() <= 255.0
|
|
||||||
image = image.type(torch.uint8)
|
|
||||||
|
|
||||||
state = states[id_from:id_to]
|
|
||||||
agent_pos = state[:, :2]
|
|
||||||
block_pos = state[:, 2:4]
|
|
||||||
block_angle = state[:, 4]
|
|
||||||
|
|
||||||
reward = torch.zeros(num_frames)
|
|
||||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
for i in range(num_frames):
|
|
||||||
space = pymunk.Space()
|
|
||||||
space.gravity = 0, 0
|
|
||||||
space.damping = 0
|
|
||||||
|
|
||||||
# Add walls.
|
|
||||||
walls = [
|
|
||||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
|
||||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
|
||||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
|
||||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
|
||||||
]
|
|
||||||
space.add(*walls)
|
|
||||||
|
|
||||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
|
||||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
|
||||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
|
||||||
intersection_area = goal_geom.intersection(block_geom).area
|
|
||||||
goal_area = goal_geom.area
|
|
||||||
coverage = intersection_area / goal_area
|
|
||||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
|
||||||
success[i] = coverage > success_threshold
|
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
|
||||||
done[-1] = True
|
|
||||||
|
|
||||||
ep_dict = {
|
|
||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
|
||||||
"observation.state": agent_pos,
|
|
||||||
"action": actions[id_from:id_to],
|
|
||||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.image": image[1:],
|
|
||||||
# "next.observation.state": agent_pos[1:],
|
|
||||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
|
||||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
|
||||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
|
||||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
|
||||||
}
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from += num_frames
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.image": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
"next.success": Value(dtype="bool", id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / "xarm_datasets_raw"
|
|
||||||
if not raw_dir.exists():
|
|
||||||
import zipfile
|
|
||||||
|
|
||||||
import gdown
|
|
||||||
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
|
||||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
|
||||||
zip_path = raw_dir / "data.zip"
|
|
||||||
gdown.download(url, str(zip_path), quiet=False)
|
|
||||||
print("Extracting...")
|
|
||||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
|
||||||
for member in zip_f.namelist():
|
|
||||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
|
||||||
print(member)
|
|
||||||
zip_f.extract(member=member)
|
|
||||||
zip_path.unlink()
|
|
||||||
|
|
||||||
dataset_path = root / f"{dataset_id}" / "buffer.pkl"
|
|
||||||
print(f"Using offline dataset '{dataset_path}'")
|
|
||||||
with open(dataset_path, "rb") as f:
|
|
||||||
dataset_dict = pickle.load(f)
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
id_to = 0
|
|
||||||
episode_id = 0
|
|
||||||
total_frames = dataset_dict["actions"].shape[0]
|
|
||||||
for i in tqdm.tqdm(range(total_frames)):
|
|
||||||
id_to += 1
|
|
||||||
|
|
||||||
if not dataset_dict["dones"][i]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
num_frames = id_to - id_from
|
|
||||||
|
|
||||||
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
|
|
||||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
|
||||||
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
|
|
||||||
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
|
|
||||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
|
||||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
|
||||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
|
|
||||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
|
|
||||||
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
|
|
||||||
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
|
|
||||||
|
|
||||||
ep_dict = {
|
|
||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
|
||||||
"observation.state": state,
|
|
||||||
"action": action,
|
|
||||||
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.image": next_image,
|
|
||||||
# "next.observation.state": next_state,
|
|
||||||
"next.reward": next_reward,
|
|
||||||
"next.done": next_done,
|
|
||||||
}
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from = id_to
|
|
||||||
episode_id += 1
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.image": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
#'next.success': Value(dtype='bool', id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|
||||||
folder_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
|
||||||
}
|
|
||||||
|
|
||||||
ep48_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
|
||||||
}
|
|
||||||
|
|
||||||
ep49_urls = {
|
|
||||||
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
|
||||||
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
|
||||||
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
|
||||||
}
|
|
||||||
|
|
||||||
num_episodes = {
|
|
||||||
"aloha_sim_insertion_human": 50,
|
|
||||||
"aloha_sim_insertion_scripted": 50,
|
|
||||||
"aloha_sim_transfer_cube_human": 50,
|
|
||||||
"aloha_sim_transfer_cube_scripted": 50,
|
|
||||||
}
|
|
||||||
|
|
||||||
episode_len = {
|
|
||||||
"aloha_sim_insertion_human": 500,
|
|
||||||
"aloha_sim_insertion_scripted": 400,
|
|
||||||
"aloha_sim_transfer_cube_human": 400,
|
|
||||||
"aloha_sim_transfer_cube_scripted": 400,
|
|
||||||
}
|
|
||||||
|
|
||||||
cameras = {
|
|
||||||
"aloha_sim_insertion_human": ["top"],
|
|
||||||
"aloha_sim_insertion_scripted": ["top"],
|
|
||||||
"aloha_sim_transfer_cube_human": ["top"],
|
|
||||||
"aloha_sim_transfer_cube_scripted": ["top"],
|
|
||||||
}
|
|
||||||
|
|
||||||
root = Path(root)
|
|
||||||
raw_dir = root / f"{dataset_id}_raw"
|
|
||||||
if not raw_dir.is_dir():
|
|
||||||
import gdown
|
|
||||||
|
|
||||||
assert dataset_id in folder_urls
|
|
||||||
assert dataset_id in ep48_urls
|
|
||||||
assert dataset_id in ep49_urls
|
|
||||||
|
|
||||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
|
|
||||||
|
|
||||||
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
|
||||||
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
|
||||||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
|
||||||
|
|
||||||
ep_dicts = []
|
|
||||||
episode_data_index = {"from": [], "to": []}
|
|
||||||
|
|
||||||
id_from = 0
|
|
||||||
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
|
||||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
|
||||||
with h5py.File(ep_path, "r") as ep:
|
|
||||||
num_frames = ep["/action"].shape[0]
|
|
||||||
assert episode_len[dataset_id] == num_frames
|
|
||||||
|
|
||||||
# last step of demonstration is considered done
|
|
||||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
|
||||||
done[-1] = True
|
|
||||||
|
|
||||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
|
||||||
action = torch.from_numpy(ep["/action"][:])
|
|
||||||
|
|
||||||
ep_dict = {}
|
|
||||||
|
|
||||||
for cam in cameras[dataset_id]:
|
|
||||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
|
||||||
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
|
||||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
|
||||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
|
||||||
|
|
||||||
ep_dict.update(
|
|
||||||
{
|
|
||||||
"observation.state": state,
|
|
||||||
"action": action,
|
|
||||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
|
||||||
"frame_index": torch.arange(0, num_frames, 1),
|
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
|
||||||
# "next.observation.state": state,
|
|
||||||
# TODO(rcadene): compute reward and success
|
|
||||||
# "next.reward": reward,
|
|
||||||
"next.done": done,
|
|
||||||
# "next.success": success,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(ep_id, int)
|
|
||||||
ep_dicts.append(ep_dict)
|
|
||||||
|
|
||||||
episode_data_index["from"].append(id_from)
|
|
||||||
episode_data_index["to"].append(id_from + num_frames)
|
|
||||||
|
|
||||||
id_from += num_frames
|
|
||||||
|
|
||||||
data_dict = concatenate_episodes(ep_dicts)
|
|
||||||
|
|
||||||
features = {
|
|
||||||
"observation.images.top": Image(),
|
|
||||||
"observation.state": Sequence(
|
|
||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
|
||||||
),
|
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
|
||||||
"episode_index": Value(dtype="int64", id=None),
|
|
||||||
"frame_index": Value(dtype="int64", id=None),
|
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
|
||||||
#'next.reward': Value(dtype='float32', id=None),
|
|
||||||
"next.done": Value(dtype="bool", id=None),
|
|
||||||
#'next.success': Value(dtype='bool', id=None),
|
|
||||||
"index": Value(dtype="int64", id=None),
|
|
||||||
}
|
|
||||||
features = Features(features)
|
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
root = "data"
|
|
||||||
revision = "v1.1"
|
|
||||||
|
|
||||||
dataset_ids = [
|
|
||||||
"pusht",
|
|
||||||
"xarm_lift_medium",
|
|
||||||
"xarm_lift_medium_replay",
|
|
||||||
"xarm_push_medium",
|
|
||||||
"xarm_push_medium_replay",
|
|
||||||
"aloha_sim_insertion_human",
|
|
||||||
"aloha_sim_insertion_scripted",
|
|
||||||
"aloha_sim_transfer_cube_human",
|
|
||||||
"aloha_sim_transfer_cube_scripted",
|
|
||||||
]
|
|
||||||
for dataset_id in dataset_ids:
|
|
||||||
download_and_upload(root, revision, dataset_id)
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""
|
|
||||||
This script demonstrates the visualization of various robotic datasets from Hugging Face hub.
|
|
||||||
It covers the steps from loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
|
|
||||||
Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
|
|
||||||
It is compatible with pytorch, jax, numpy, etc.
|
|
||||||
|
|
||||||
As an example, this script saves frames of episode number 5 of the PushT dataset to a mp4 video and saves the result here:
|
|
||||||
`outputs/examples/1_visualize_hugging_face_datasets/episode_5.mp4`
|
|
||||||
|
|
||||||
This script supports several Hugging Face datasets, among which:
|
|
||||||
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
|
|
||||||
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
|
|
||||||
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
|
|
||||||
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
|
|
||||||
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
|
|
||||||
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
|
||||||
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
|
||||||
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
|
||||||
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
|
||||||
|
|
||||||
To try a different Hugging Face dataset, you can replace this line:
|
|
||||||
```python
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
|
||||||
```
|
|
||||||
by one of these:
|
|
||||||
```python
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
# TODO(rcadene): remove this example file of using hf_dataset
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import imageio
|
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
|
||||||
|
|
||||||
# download/load hugging face dataset in pyarrow format
|
|
||||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
|
|
||||||
|
|
||||||
# display name of dataset and its features
|
|
||||||
# TODO(rcadene): update to make the print pretty
|
|
||||||
print(f"{hf_dataset=}")
|
|
||||||
print(f"{hf_dataset.features=}")
|
|
||||||
|
|
||||||
# display useful statistics about frames and episodes, which are sequences of frames from the same video
|
|
||||||
print(f"number of frames: {len(hf_dataset)=}")
|
|
||||||
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
|
|
||||||
print(
|
|
||||||
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# select the frames belonging to episode number 5
|
|
||||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
|
||||||
|
|
||||||
# load all frames of episode 5 in RAM in PIL format
|
|
||||||
frames = hf_dataset["observation.image"]
|
|
||||||
|
|
||||||
# save episode frames to a mp4 video
|
|
||||||
Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
|
|
||||||
imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)
|
|
||||||
@@ -58,8 +58,8 @@ frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
|||||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||||
|
|
||||||
# and finally save them to a mp4 video
|
# and finally save them to a mp4 video
|
||||||
Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||||
imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
|
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
|
||||||
|
|
||||||
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
|
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
|
||||||
# using timestamps differences with the current loaded frame. For instance:
|
# using timestamps differences with the current loaded frame. For instance:
|
||||||
@@ -25,6 +25,8 @@ When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps
|
|||||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
from lerobot.__version__ import __version__ # noqa: F401
|
from lerobot.__version__ import __version__ # noqa: F401
|
||||||
|
|
||||||
available_tasks_per_env = {
|
available_tasks_per_env = {
|
||||||
@@ -52,7 +54,19 @@ available_datasets_per_env = {
|
|||||||
"lerobot/xarm_push_medium_replay",
|
"lerobot/xarm_push_medium_replay",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
|
|
||||||
|
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
|
||||||
|
|
||||||
|
available_datasets = list(
|
||||||
|
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
|
||||||
|
available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
|
||||||
|
|
||||||
|
available_datasets = list(
|
||||||
|
itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
|
||||||
|
)
|
||||||
|
|
||||||
available_policies = [
|
available_policies = [
|
||||||
"act",
|
"act",
|
||||||
|
|||||||
179
lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
Normal file
179
lerobot/common/datasets/push_dataset_to_hub/_download_raw.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||||
|
useless dependencies when using datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def download_raw(root, dataset_id) -> Path:
|
||||||
|
if "pusht" in dataset_id:
|
||||||
|
return download_pusht(root=root, dataset_id=dataset_id)
|
||||||
|
elif "xarm" in dataset_id:
|
||||||
|
return download_xarm(root=root, dataset_id=dataset_id)
|
||||||
|
elif "aloha" in dataset_id:
|
||||||
|
return download_aloha(root=root, dataset_id=dataset_id)
|
||||||
|
elif "umi" in dataset_id:
|
||||||
|
return download_umi(root=root, dataset_id=dataset_id)
|
||||||
|
else:
|
||||||
|
raise ValueError(dataset_id)
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
print(f"downloading from {url}")
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
if response.status_code == 200:
|
||||||
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
|
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||||
|
|
||||||
|
zip_file = io.BytesIO()
|
||||||
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
zip_file.write(chunk)
|
||||||
|
progress_bar.update(len(chunk))
|
||||||
|
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
zip_file.seek(0)
|
||||||
|
|
||||||
|
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(destination_folder)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path:
|
||||||
|
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||||
|
pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")
|
||||||
|
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||||
|
zarr_path: Path = (raw_dir / pusht_zarr).resolve()
|
||||||
|
if not zarr_path.is_dir():
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
download_and_extract_zip(pusht_url, raw_dir)
|
||||||
|
return zarr_path
|
||||||
|
|
||||||
|
|
||||||
|
def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir: Path = root / "xarm_datasets_raw"
|
||||||
|
if not raw_dir.exists():
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||||
|
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||||
|
zip_path = raw_dir / "data.zip"
|
||||||
|
gdown.download(url, str(zip_path), quiet=False)
|
||||||
|
print("Extracting...")
|
||||||
|
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||||
|
for member in zip_f.namelist():
|
||||||
|
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||||
|
print(member)
|
||||||
|
zip_f.extract(member=member)
|
||||||
|
zip_path.unlink()
|
||||||
|
|
||||||
|
dataset_path: Path = root / f"{dataset_id}"
|
||||||
|
return dataset_path
|
||||||
|
|
||||||
|
|
||||||
|
def download_aloha(root: str, dataset_id: str) -> Path:
|
||||||
|
folder_urls = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
|
||||||
|
}
|
||||||
|
|
||||||
|
ep48_urls = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
|
||||||
|
}
|
||||||
|
|
||||||
|
ep49_urls = {
|
||||||
|
"aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
|
||||||
|
"aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
|
||||||
|
"aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
|
||||||
|
}
|
||||||
|
num_episodes = { # noqa: F841 # we keep this for reference
|
||||||
|
"aloha_sim_insertion_human": 50,
|
||||||
|
"aloha_sim_insertion_scripted": 50,
|
||||||
|
"aloha_sim_transfer_cube_human": 50,
|
||||||
|
"aloha_sim_transfer_cube_scripted": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
episode_len = { # noqa: F841 # we keep this for reference
|
||||||
|
"aloha_sim_insertion_human": 500,
|
||||||
|
"aloha_sim_insertion_scripted": 400,
|
||||||
|
"aloha_sim_transfer_cube_human": 400,
|
||||||
|
"aloha_sim_transfer_cube_scripted": 400,
|
||||||
|
}
|
||||||
|
|
||||||
|
cameras = { # noqa: F841 # we keep this for reference
|
||||||
|
"aloha_sim_insertion_human": ["top"],
|
||||||
|
"aloha_sim_insertion_scripted": ["top"],
|
||||||
|
"aloha_sim_transfer_cube_human": ["top"],
|
||||||
|
"aloha_sim_transfer_cube_scripted": ["top"],
|
||||||
|
}
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||||
|
if not raw_dir.is_dir():
|
||||||
|
import gdown
|
||||||
|
|
||||||
|
assert dataset_id in folder_urls
|
||||||
|
assert dataset_id in ep48_urls
|
||||||
|
assert dataset_id in ep49_urls
|
||||||
|
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))
|
||||||
|
|
||||||
|
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
|
||||||
|
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
||||||
|
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
||||||
|
return raw_dir
|
||||||
|
|
||||||
|
|
||||||
|
def download_umi(root: str, dataset_id: str) -> Path:
|
||||||
|
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
|
||||||
|
cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")
|
||||||
|
|
||||||
|
root = Path(root)
|
||||||
|
raw_dir: Path = root / f"{dataset_id}_raw"
|
||||||
|
zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve()
|
||||||
|
if not zarr_path.is_dir():
|
||||||
|
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
|
||||||
|
return zarr_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
root = "data"
|
||||||
|
dataset_ids = [
|
||||||
|
"pusht",
|
||||||
|
"xarm_lift_medium",
|
||||||
|
"xarm_lift_medium_replay",
|
||||||
|
"xarm_push_medium",
|
||||||
|
"xarm_push_medium_replay",
|
||||||
|
"aloha_sim_insertion_human",
|
||||||
|
"aloha_sim_insertion_scripted",
|
||||||
|
"aloha_sim_transfer_cube_human",
|
||||||
|
"aloha_sim_transfer_cube_scripted",
|
||||||
|
"umi_cup_in_the_wild",
|
||||||
|
]
|
||||||
|
for dataset_id in dataset_ids:
|
||||||
|
download_raw(root=root, dataset_id=dataset_id)
|
||||||
@@ -0,0 +1,311 @@
|
|||||||
|
# imagecodecs/numcodecs.py
|
||||||
|
|
||||||
|
# Copyright (c) 2021-2022, Christoph Gohlke
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
# contributors may be used to endorse or promote products derived from
|
||||||
|
# this software without specific prior written permission.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
# POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
# Copied from: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/codecs/imagecodecs_numcodecs.py#L1
|
||||||
|
"""Additional numcodecs implemented using imagecodecs."""
|
||||||
|
|
||||||
|
__version__ = "2022.9.26"
|
||||||
|
|
||||||
|
__all__ = ("register_codecs",)
|
||||||
|
|
||||||
|
import imagecodecs
|
||||||
|
import numpy
|
||||||
|
from numcodecs.abc import Codec
|
||||||
|
from numcodecs.registry import get_codec, register_codec
|
||||||
|
|
||||||
|
# TODO (azouitine): Remove useless codecs
|
||||||
|
|
||||||
|
|
||||||
|
def protective_squeeze(x: numpy.ndarray):
|
||||||
|
"""
|
||||||
|
Squeeze dim only if it's not the last dim.
|
||||||
|
Image dim expected to be *, H, W, C
|
||||||
|
"""
|
||||||
|
img_shape = x.shape[-3:]
|
||||||
|
if len(x.shape) > 3:
|
||||||
|
n_imgs = numpy.prod(x.shape[:-3])
|
||||||
|
if n_imgs > 1:
|
||||||
|
img_shape = (-1,) + img_shape
|
||||||
|
return x.reshape(img_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_image_compressor(**kwargs):
|
||||||
|
if imagecodecs.JPEGXL:
|
||||||
|
# has JPEGXL
|
||||||
|
this_kwargs = {
|
||||||
|
"effort": 3,
|
||||||
|
"distance": 0.3,
|
||||||
|
# bug in libjxl, invalid codestream for non-lossless
|
||||||
|
# when decoding speed > 1
|
||||||
|
"decodingspeed": 1,
|
||||||
|
}
|
||||||
|
this_kwargs.update(kwargs)
|
||||||
|
return JpegXl(**this_kwargs)
|
||||||
|
else:
|
||||||
|
this_kwargs = {"level": 50}
|
||||||
|
this_kwargs.update(kwargs)
|
||||||
|
return Jpeg2k(**this_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Jpeg2k(Codec):
|
||||||
|
"""JPEG 2000 codec for numcodecs."""
|
||||||
|
|
||||||
|
codec_id = "imagecodecs_jpeg2k"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
level=None,
|
||||||
|
codecformat=None,
|
||||||
|
colorspace=None,
|
||||||
|
tile=None,
|
||||||
|
reversible=None,
|
||||||
|
bitspersample=None,
|
||||||
|
resolutions=None,
|
||||||
|
numthreads=None,
|
||||||
|
verbose=0,
|
||||||
|
):
|
||||||
|
self.level = level
|
||||||
|
self.codecformat = codecformat
|
||||||
|
self.colorspace = colorspace
|
||||||
|
self.tile = None if tile is None else tuple(tile)
|
||||||
|
self.reversible = reversible
|
||||||
|
self.bitspersample = bitspersample
|
||||||
|
self.resolutions = resolutions
|
||||||
|
self.numthreads = numthreads
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def encode(self, buf):
|
||||||
|
buf = protective_squeeze(numpy.asarray(buf))
|
||||||
|
return imagecodecs.jpeg2k_encode(
|
||||||
|
buf,
|
||||||
|
level=self.level,
|
||||||
|
codecformat=self.codecformat,
|
||||||
|
colorspace=self.colorspace,
|
||||||
|
tile=self.tile,
|
||||||
|
reversible=self.reversible,
|
||||||
|
bitspersample=self.bitspersample,
|
||||||
|
resolutions=self.resolutions,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
verbose=self.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, buf, out=None):
|
||||||
|
return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out)
|
||||||
|
|
||||||
|
|
||||||
|
class JpegXl(Codec):
|
||||||
|
"""JPEG XL codec for numcodecs."""
|
||||||
|
|
||||||
|
codec_id = "imagecodecs_jpegxl"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
# encode
|
||||||
|
level=None,
|
||||||
|
effort=None,
|
||||||
|
distance=None,
|
||||||
|
lossless=None,
|
||||||
|
decodingspeed=None,
|
||||||
|
photometric=None,
|
||||||
|
planar=None,
|
||||||
|
usecontainer=None,
|
||||||
|
# decode
|
||||||
|
index=None,
|
||||||
|
keeporientation=None,
|
||||||
|
# both
|
||||||
|
numthreads=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return JPEG XL image from numpy array.
|
||||||
|
Float must be in nominal range 0..1.
|
||||||
|
|
||||||
|
Currently L, LA, RGB, RGBA images are supported in contig mode.
|
||||||
|
Extra channels are only supported for grayscale images in planar mode.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
level : Default to None, i.e. not overwriting lossess and decodingspeed options.
|
||||||
|
When < 0: Use lossless compression
|
||||||
|
When in [0,1,2,3,4]: Sets the decoding speed tier for the provided options.
|
||||||
|
Minimum is 0 (slowest to decode, best quality/density), and maximum
|
||||||
|
is 4 (fastest to decode, at the cost of some quality/density).
|
||||||
|
effort : Default to 3.
|
||||||
|
Sets encoder effort/speed level without affecting decoding speed.
|
||||||
|
Valid values are, from faster to slower speed: 1:lightning 2:thunder
|
||||||
|
3:falcon 4:cheetah 5:hare 6:wombat 7:squirrel 8:kitten 9:tortoise.
|
||||||
|
Speed: lightning, thunder, falcon, cheetah, hare, wombat, squirrel, kitten, tortoise
|
||||||
|
control the encoder effort in ascending order.
|
||||||
|
This also affects memory usage: using lower effort will typically reduce memory
|
||||||
|
consumption during encoding.
|
||||||
|
lightning and thunder are fast modes useful for lossless mode (modular).
|
||||||
|
falcon disables all of the following tools.
|
||||||
|
cheetah enables coefficient reordering, context clustering, and heuristics for selecting DCT sizes and quantization steps.
|
||||||
|
hare enables Gaborish filtering, chroma from luma, and an initial estimate of quantization steps.
|
||||||
|
wombat enables error diffusion quantization and full DCT size selection heuristics.
|
||||||
|
squirrel (default) enables dots, patches, and spline detection, and full context clustering.
|
||||||
|
kitten optimizes the adaptive quantization for a psychovisual metric.
|
||||||
|
tortoise enables a more thorough adaptive quantization search.
|
||||||
|
distance : Default to 1.0
|
||||||
|
Sets the distance level for lossy compression: target max butteraugli distance,
|
||||||
|
lower = higher quality. Range: 0 .. 15. 0.0 = mathematically lossless
|
||||||
|
(however, use JxlEncoderSetFrameLossless instead to use true lossless,
|
||||||
|
as setting distance to 0 alone is not the only requirement).
|
||||||
|
1.0 = visually lossless. Recommended range: 0.5 .. 3.0.
|
||||||
|
lossess : Default to False.
|
||||||
|
Use lossess encoding.
|
||||||
|
decodingspeed : Default to 0.
|
||||||
|
Duplicate to level. [0,4]
|
||||||
|
photometric : Return JxlColorSpace value.
|
||||||
|
Default logic is quite complicated but works most of the time.
|
||||||
|
Accepted value:
|
||||||
|
int: [-1,3]
|
||||||
|
str: ['RGB',
|
||||||
|
'WHITEISZERO', 'MINISWHITE',
|
||||||
|
'BLACKISZERO', 'MINISBLACK', 'GRAY',
|
||||||
|
'XYB', 'KNOWN']
|
||||||
|
planar : Enable multi-channel mode.
|
||||||
|
Default to false.
|
||||||
|
usecontainer :
|
||||||
|
Forces the encoder to use the box-based container format (BMFF)
|
||||||
|
even when not necessary.
|
||||||
|
When using JxlEncoderUseBoxes, JxlEncoderStoreJPEGMetadata or
|
||||||
|
JxlEncoderSetCodestreamLevel with level 10, the encoder will
|
||||||
|
automatically also use the container format, it is not necessary
|
||||||
|
to use JxlEncoderUseContainer for those use cases.
|
||||||
|
By default this setting is disabled.
|
||||||
|
index : Selectively decode frames for animation.
|
||||||
|
Default to 0, decode all frames.
|
||||||
|
When set to > 0, decode that frame index only.
|
||||||
|
keeporientation :
|
||||||
|
Enables or disables preserving of as-in-bitstream pixeldata orientation.
|
||||||
|
Some images are encoded with an Orientation tag indicating that the
|
||||||
|
decoder must perform a rotation and/or mirroring to the encoded image data.
|
||||||
|
|
||||||
|
If skip_reorientation is JXL_FALSE (the default): the decoder will apply
|
||||||
|
the transformation from the orientation setting, hence rendering the image
|
||||||
|
according to its specified intent. When producing a JxlBasicInfo, the decoder
|
||||||
|
will always set the orientation field to JXL_ORIENT_IDENTITY (matching the
|
||||||
|
returned pixel data) and also align xsize and ysize so that they correspond
|
||||||
|
to the width and the height of the returned pixel data.
|
||||||
|
|
||||||
|
If skip_reorientation is JXL_TRUE: the decoder will skip applying the
|
||||||
|
transformation from the orientation setting, returning the image in
|
||||||
|
the as-in-bitstream pixeldata orientation. This may be faster to decode
|
||||||
|
since the decoder doesnt have to apply the transformation, but can
|
||||||
|
cause wrong display of the image if the orientation tag is not correctly
|
||||||
|
taken into account by the user.
|
||||||
|
|
||||||
|
By default, this option is disabled, and the returned pixel data is
|
||||||
|
re-oriented according to the images Orientation setting.
|
||||||
|
threads : Default to 1.
|
||||||
|
If <= 0, use all cores.
|
||||||
|
If > 32, clipped to 32.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.level = level
|
||||||
|
self.effort = effort
|
||||||
|
self.distance = distance
|
||||||
|
self.lossless = bool(lossless)
|
||||||
|
self.decodingspeed = decodingspeed
|
||||||
|
self.photometric = photometric
|
||||||
|
self.planar = planar
|
||||||
|
self.usecontainer = usecontainer
|
||||||
|
self.index = index
|
||||||
|
self.keeporientation = keeporientation
|
||||||
|
self.numthreads = numthreads
|
||||||
|
|
||||||
|
def encode(self, buf):
|
||||||
|
# TODO: only squeeze all but last dim
|
||||||
|
buf = protective_squeeze(numpy.asarray(buf))
|
||||||
|
return imagecodecs.jpegxl_encode(
|
||||||
|
buf,
|
||||||
|
level=self.level,
|
||||||
|
effort=self.effort,
|
||||||
|
distance=self.distance,
|
||||||
|
lossless=self.lossless,
|
||||||
|
decodingspeed=self.decodingspeed,
|
||||||
|
photometric=self.photometric,
|
||||||
|
planar=self.planar,
|
||||||
|
usecontainer=self.usecontainer,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, buf, out=None):
|
||||||
|
return imagecodecs.jpegxl_decode(
|
||||||
|
buf,
|
||||||
|
index=self.index,
|
||||||
|
keeporientation=self.keeporientation,
|
||||||
|
numthreads=self.numthreads,
|
||||||
|
out=out,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _flat(out):
|
||||||
|
"""Return numpy array as contiguous view of bytes if possible."""
|
||||||
|
if out is None:
|
||||||
|
return None
|
||||||
|
view = memoryview(out)
|
||||||
|
if view.readonly or not view.contiguous:
|
||||||
|
return None
|
||||||
|
return view.cast("B")
|
||||||
|
|
||||||
|
|
||||||
|
def register_codecs(codecs=None, force=False, verbose=True):
    """Register codecs in this module with numcodecs."""
    for name, cls in globals().items():
        # Only consider codec classes (skip the abstract `Codec` base itself).
        if name == "Codec" or not hasattr(cls, "codec_id"):
            continue
        if codecs is not None and cls.codec_id not in codecs:
            continue
        # Probe numcodecs: ValueError means the id is unknown (not registered);
        # TypeError means it is registered but fails to instantiate with no args.
        already_registered = True
        try:
            try:  # noqa: SIM105
                get_codec({"id": cls.codec_id})
            except TypeError:
                # registered, but failed
                pass
        except ValueError:
            # not registered yet
            already_registered = False
        if already_registered:
            if not force:
                if verbose:
                    log_warning(f"numcodec {cls.codec_id!r} already registered")
                continue
            if verbose:
                log_warning(f"replacing registered numcodec {cls.codec_id!r}")
        register_codec(cls)
|
||||||
|
|
||||||
|
|
||||||
|
def log_warning(msg, *args, **kwargs):
    """Log message with level WARNING."""
    import logging

    logger = logging.getLogger(__name__)
    logger.warning(msg, *args, **kwargs)
|
||||||
199
lerobot/common/datasets/push_dataset_to_hub/aloha_processor.py
Normal file
199
lerobot/common/datasets/push_dataset_to_hub/aloha_processor.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AlohaProcessor:
|
||||||
|
"""
|
||||||
|
Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
folder_path (Path): Path to the directory containing HDF5 files.
|
||||||
|
cameras (list[str]): List of camera identifiers to check in the files.
|
||||||
|
fps (int): Frames per second used in timestamp calculations.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
is_valid() -> bool:
|
||||||
|
Validates if each HDF5 file within the folder contains all required datasets.
|
||||||
|
preprocess() -> dict:
|
||||||
|
Processes the files and returns structured data suitable for further analysis.
|
||||||
|
to_hf_dataset(data_dict: dict) -> Dataset:
|
||||||
|
Converts processed data into a Hugging Face Dataset object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None):
|
||||||
|
"""
|
||||||
|
Initializes the AlohaProcessor with a specified directory path containing HDF5 files,
|
||||||
|
an optional list of cameras, and a frame rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_path (Path): The directory path where HDF5 files are stored.
|
||||||
|
cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None.
|
||||||
|
fps (int): Frame rate for the datasets, used in time calculations. Default is 50.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"])
|
||||||
|
>>> processor.is_valid()
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
self.folder_path = folder_path
|
||||||
|
if cameras is None:
|
||||||
|
cameras = ["top"]
|
||||||
|
self.cameras = cameras
|
||||||
|
if fps is None:
|
||||||
|
fps = 50
|
||||||
|
self._fps = fps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
return self._fps
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""
|
||||||
|
Validates the HDF5 files in the specified folder to ensure they contain the required datasets
|
||||||
|
for actions, positions, and images for each specified camera.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all files are valid HDF5 files with all required datasets, False otherwise.
|
||||||
|
"""
|
||||||
|
hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5"))
|
||||||
|
if len(hdf5_files) == 0:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
hdf5_files = sorted(
|
||||||
|
hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1))
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# All file names must contain a numerical identifier matching 'episode_(\\d+).hdf5
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the sequence is consecutive eg episode_0, episode_1, episode_2, etc.
|
||||||
|
# If not, return False
|
||||||
|
previous_number = None
|
||||||
|
for file in hdf5_files:
|
||||||
|
current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
|
||||||
|
if previous_number is not None and current_number - previous_number != 1:
|
||||||
|
return False
|
||||||
|
previous_number = current_number
|
||||||
|
|
||||||
|
for file in hdf5_files:
|
||||||
|
try:
|
||||||
|
with h5py.File(file, "r") as file:
|
||||||
|
# Check for the expected datasets within the HDF5 file
|
||||||
|
required_datasets = ["/action", "/observations/qpos"]
|
||||||
|
# Add camera-specific image datasets to the required datasets
|
||||||
|
camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras]
|
||||||
|
required_datasets.extend(camera_datasets)
|
||||||
|
|
||||||
|
if not all(dataset in file for dataset in required_datasets):
|
||||||
|
return False
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
"""
|
||||||
|
Collects episode data from the HDF5 file and returns it as an AlohaStep named tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AlohaStep: Named tuple containing episode data.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the file is not valid.
|
||||||
|
"""
|
||||||
|
if not self.is_valid():
|
||||||
|
raise ValueError("The HDF5 file is invalid or does not contain the required datasets.")
|
||||||
|
|
||||||
|
hdf5_files = list(self.folder_path.glob("*.hdf5"))
|
||||||
|
hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
|
id_from = 0
|
||||||
|
|
||||||
|
for ep_path in tqdm.tqdm(hdf5_files):
|
||||||
|
with h5py.File(ep_path, "r") as ep:
|
||||||
|
ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
|
||||||
|
num_frames = ep["/action"].shape[0]
|
||||||
|
|
||||||
|
# last step of demonstration is considered done
|
||||||
|
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||||
|
done[-1] = True
|
||||||
|
|
||||||
|
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||||
|
action = torch.from_numpy(ep["/action"][:])
|
||||||
|
|
||||||
|
ep_dict = {}
|
||||||
|
|
||||||
|
for cam in self.cameras:
|
||||||
|
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||||
|
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||||
|
|
||||||
|
ep_dict.update(
|
||||||
|
{
|
||||||
|
"observation.state": state,
|
||||||
|
"action": action,
|
||||||
|
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||||
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
|
# TODO(rcadene): compute reward and success
|
||||||
|
# "next.reward": reward,
|
||||||
|
"next.done": done,
|
||||||
|
# "next.success": success,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(ep_id, int)
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames)
|
||||||
|
|
||||||
|
id_from += num_frames
|
||||||
|
|
||||||
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
return data_dict, episode_data_index
|
||||||
|
|
||||||
|
def to_hf_dataset(self, data_dict) -> Dataset:
|
||||||
|
"""
|
||||||
|
Converts a dictionary of data into a Hugging Face Dataset object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict (dict): A dictionary containing the data to be converted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset: The converted Hugging Face Dataset object.
|
||||||
|
"""
|
||||||
|
image_features = {f"observation.images.{cam}": Image() for cam in self.cameras}
|
||||||
|
features = {
|
||||||
|
"observation.state": Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
# "next.reward": Value(dtype="float32", id=None),
|
||||||
|
"next.done": Value(dtype="bool", id=None),
|
||||||
|
# "next.success": Value(dtype="bool", id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
}
|
||||||
|
update_features = {**image_features, **features}
|
||||||
|
features = Features(update_features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
|
return hf_dataset
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
pass
|
||||||
180
lerobot/common/datasets/push_dataset_to_hub/pusht_processor.py
Normal file
180
lerobot/common/datasets/push_dataset_to_hub/pusht_processor.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import zarr
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PushTProcessor:
|
||||||
|
""" Process zarr files formatted like in: https://github.com/real-stanford/diffusion_policy
|
||||||
|
"""
|
||||||
|
def __init__(self, folder_path: Path, fps: int | None = None):
|
||||||
|
self.zarr_path = folder_path
|
||||||
|
if fps is None:
|
||||||
|
fps = 10
|
||||||
|
self._fps = fps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
return self._fps
|
||||||
|
|
||||||
|
def is_valid(self):
|
||||||
|
try:
|
||||||
|
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||||
|
except Exception:
|
||||||
|
# TODO (azouitine): Handle the exception properly
|
||||||
|
return False
|
||||||
|
required_datasets = {
|
||||||
|
"data/action",
|
||||||
|
"data/img",
|
||||||
|
"data/keypoint",
|
||||||
|
"data/n_contacts",
|
||||||
|
"data/state",
|
||||||
|
"meta/episode_ends",
|
||||||
|
}
|
||||||
|
for dataset in required_datasets:
|
||||||
|
if dataset not in zarr_data:
|
||||||
|
return False
|
||||||
|
nb_frames = zarr_data["data/img"].shape[0]
|
||||||
|
|
||||||
|
required_datasets.remove("meta/episode_ends")
|
||||||
|
|
||||||
|
return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
try:
|
||||||
|
import pymunk
|
||||||
|
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||||
|
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||||
|
)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# as define in env
|
||||||
|
success_threshold = 0.95 # 95% coverage,
|
||||||
|
|
||||||
|
dataset_dict = DiffusionPolicyReplayBuffer.copy_from_path(
|
||||||
|
self.zarr_path
|
||||||
|
) # , keys=['img', 'state', 'action'])
|
||||||
|
|
||||||
|
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||||
|
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||||
|
assert len(
|
||||||
|
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||||
|
), "Some data type dont have the same number of total frames."
|
||||||
|
|
||||||
|
# TODO: verify that goal pose is expected to be fixed
|
||||||
|
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||||
|
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||||
|
|
||||||
|
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
|
||||||
|
states = torch.from_numpy(dataset_dict["state"])
|
||||||
|
actions = torch.from_numpy(dataset_dict["action"])
|
||||||
|
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
|
||||||
|
id_from = 0
|
||||||
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
|
id_to = dataset_dict.meta["episode_ends"][episode_id]
|
||||||
|
|
||||||
|
num_frames = id_to - id_from
|
||||||
|
|
||||||
|
assert (episode_ids[id_from:id_to] == episode_id).all()
|
||||||
|
|
||||||
|
image = imgs[id_from:id_to]
|
||||||
|
assert image.min() >= 0.0
|
||||||
|
assert image.max() <= 255.0
|
||||||
|
image = image.type(torch.uint8)
|
||||||
|
|
||||||
|
state = states[id_from:id_to]
|
||||||
|
agent_pos = state[:, :2]
|
||||||
|
block_pos = state[:, 2:4]
|
||||||
|
block_angle = state[:, 4]
|
||||||
|
|
||||||
|
reward = torch.zeros(num_frames)
|
||||||
|
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||||
|
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||||
|
for i in range(num_frames):
|
||||||
|
space = pymunk.Space()
|
||||||
|
space.gravity = 0, 0
|
||||||
|
space.damping = 0
|
||||||
|
|
||||||
|
# Add walls.
|
||||||
|
walls = [
|
||||||
|
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||||
|
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||||
|
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||||
|
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||||
|
]
|
||||||
|
space.add(*walls)
|
||||||
|
|
||||||
|
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||||
|
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||||
|
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||||
|
intersection_area = goal_geom.intersection(block_geom).area
|
||||||
|
goal_area = goal_geom.area
|
||||||
|
coverage = intersection_area / goal_area
|
||||||
|
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||||
|
success[i] = coverage > success_threshold
|
||||||
|
|
||||||
|
# last step of demonstration is considered done
|
||||||
|
done[-1] = True
|
||||||
|
|
||||||
|
ep_dict = {
|
||||||
|
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||||
|
"observation.state": agent_pos,
|
||||||
|
"action": actions[id_from:id_to],
|
||||||
|
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
|
# "next.observation.image": image[1:],
|
||||||
|
# "next.observation.state": agent_pos[1:],
|
||||||
|
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||||
|
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||||
|
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||||
|
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||||
|
}
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames)
|
||||||
|
|
||||||
|
id_from += num_frames
|
||||||
|
|
||||||
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
return data_dict, episode_data_index
|
||||||
|
|
||||||
|
def to_hf_dataset(self, data_dict):
|
||||||
|
features = {
|
||||||
|
"observation.image": Image(),
|
||||||
|
"observation.state": Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
"next.reward": Value(dtype="float32", id=None),
|
||||||
|
"next.done": Value(dtype="bool", id=None),
|
||||||
|
"next.success": Value(dtype="bool", id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
}
|
||||||
|
features = Features(features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
return hf_dataset
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
pass
|
||||||
280
lerobot/common/datasets/push_dataset_to_hub/umi_processor.py
Normal file
280
lerobot/common/datasets/push_dataset_to_hub/umi_processor.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import zarr
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UmiProcessor:
|
||||||
|
"""
|
||||||
|
Process UMI (Universal Manipulation Interface) data stored in Zarr format like in: https://github.com/real-stanford/universal_manipulation_interface
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
folder_path (str): The path to the folder containing Zarr datasets.
|
||||||
|
fps (int): Frames per second, used to calculate timestamps for frames.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, folder_path: str, fps: int | None = None):
|
||||||
|
self.zarr_path = folder_path
|
||||||
|
if fps is None:
|
||||||
|
# TODO (azouitine): Add reference to the paper
|
||||||
|
fps = 15
|
||||||
|
self._fps = fps
|
||||||
|
register_codecs()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
return self._fps
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""
|
||||||
|
Validates the Zarr folder to ensure it contains all required datasets with consistent frame counts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if all required datasets are present and have consistent frame counts, False otherwise.
|
||||||
|
"""
|
||||||
|
# Check if the Zarr folder is valid
|
||||||
|
try:
|
||||||
|
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||||
|
except Exception:
|
||||||
|
# TODO (azouitine): Handle the exception properly
|
||||||
|
return False
|
||||||
|
required_datasets = {
|
||||||
|
"data/robot0_demo_end_pose",
|
||||||
|
"data/robot0_demo_start_pose",
|
||||||
|
"data/robot0_eef_pos",
|
||||||
|
"data/robot0_eef_rot_axis_angle",
|
||||||
|
"data/robot0_gripper_width",
|
||||||
|
"meta/episode_ends",
|
||||||
|
"data/camera0_rgb",
|
||||||
|
}
|
||||||
|
for dataset in required_datasets:
|
||||||
|
if dataset not in zarr_data:
|
||||||
|
return False
|
||||||
|
nb_frames = zarr_data["data/camera0_rgb"].shape[0]
|
||||||
|
|
||||||
|
required_datasets.remove("meta/episode_ends")
|
||||||
|
|
||||||
|
return all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||||
|
|
||||||
|
def preprocess(self):
|
||||||
|
"""
|
||||||
|
Collects and processes all episodes from the Zarr dataset into structured data dictionaries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, Dict]: A tuple containing the structured episode data and episode index mappings.
|
||||||
|
"""
|
||||||
|
zarr_data = zarr.open(self.zarr_path, mode="r")
|
||||||
|
|
||||||
|
# We process the image data separately because it is too large to fit in memory
|
||||||
|
end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:])
|
||||||
|
start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:])
|
||||||
|
eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:])
|
||||||
|
eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:])
|
||||||
|
gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:])
|
||||||
|
|
||||||
|
states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1)
|
||||||
|
states = torch.cat([states_pos, gripper_width], dim=1)
|
||||||
|
|
||||||
|
episode_ends = zarr_data["meta/episode_ends"][:]
|
||||||
|
num_episodes: int = episode_ends.shape[0]
|
||||||
|
|
||||||
|
episode_ids = torch.from_numpy(self.get_episode_idxs(episode_ends))
|
||||||
|
|
||||||
|
# We convert it in torch tensor later because the jit function does not support torch tensors
|
||||||
|
episode_ends = torch.from_numpy(episode_ends)
|
||||||
|
|
||||||
|
ep_dicts = []
|
||||||
|
episode_data_index = {"from": [], "to": []}
|
||||||
|
id_from = 0
|
||||||
|
|
||||||
|
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||||
|
id_to = episode_ends[episode_id]
|
||||||
|
|
||||||
|
num_frames = id_to - id_from
|
||||||
|
|
||||||
|
assert (
|
||||||
|
episode_ids[id_from:id_to] == episode_id
|
||||||
|
).all(), f"episode_ids[{id_from}:{id_to}] != {episode_id}"
|
||||||
|
|
||||||
|
state = states[id_from:id_to]
|
||||||
|
ep_dict = {
|
||||||
|
# observation.image will be filled later
|
||||||
|
"observation.state": state,
|
||||||
|
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
|
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||||
|
"episode_data_index_from": torch.tensor([id_from] * num_frames),
|
||||||
|
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
|
||||||
|
"end_pose": end_pose[id_from:id_to],
|
||||||
|
"start_pos": start_pos[id_from:id_to],
|
||||||
|
"gripper_width": gripper_width[id_from:id_to],
|
||||||
|
}
|
||||||
|
ep_dicts.append(ep_dict)
|
||||||
|
episode_data_index["from"].append(id_from)
|
||||||
|
episode_data_index["to"].append(id_from + num_frames)
|
||||||
|
id_from += num_frames
|
||||||
|
|
||||||
|
data_dict = concatenate_episodes(ep_dicts)
|
||||||
|
|
||||||
|
total_frames = id_from
|
||||||
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
|
|
||||||
|
print("Saving images to disk in temporary folder...")
|
||||||
|
# datasets.Image() can take a list of paths to images, so we save the images to a temporary folder
|
||||||
|
# to avoid loading them all in memory
|
||||||
|
_save_images_concurrently(
|
||||||
|
data=zarr_data, image_key="data/camera0_rgb", folder_path="tmp_umi_images", max_workers=12
|
||||||
|
)
|
||||||
|
print("Saving images to disk in temporary folder... Done")
|
||||||
|
|
||||||
|
# Sort files by number eg. 1.png, 2.png, 3.png, 9.png, 10.png instead of 1.png, 10.png, 2.png, 3.png, 9.png
|
||||||
|
# to correctly match the images with the data
|
||||||
|
images_path = sorted(
|
||||||
|
glob("tmp_umi_images/*"), key=lambda x: int(re.search(r"(\d+)\.png$", x).group(1))
|
||||||
|
)
|
||||||
|
data_dict["observation.image"] = images_path
|
||||||
|
print("Images saved to disk, do not forget to delete the folder tmp_umi_images/")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
return data_dict, episode_data_index
|
||||||
|
|
||||||
|
def to_hf_dataset(self, data_dict):
|
||||||
|
"""
|
||||||
|
Converts the processed data dictionary into a Hugging Face dataset with defined features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict (Dict): The data dictionary containing tensors and episode information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset: A Hugging Face dataset constructed from the provided data dictionary.
|
||||||
|
"""
|
||||||
|
features = {
|
||||||
|
"observation.image": Image(),
|
||||||
|
"observation.state": Sequence(
|
||||||
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
|
"index": Value(dtype="int64", id=None),
|
||||||
|
"episode_data_index_from": Value(dtype="int64", id=None),
|
||||||
|
"episode_data_index_to": Value(dtype="int64", id=None),
|
||||||
|
# `start_pos` and `end_pos` respectively represent the positions of the end-effector
|
||||||
|
# at the beginning and the end of the episode.
|
||||||
|
# `gripper_width` indicates the distance between the grippers, and this value is included
|
||||||
|
# in the state vector, which comprises the concatenation of the end-effector position
|
||||||
|
# and gripper width.
|
||||||
|
"end_pose": Sequence(
|
||||||
|
length=data_dict["end_pose"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"start_pos": Sequence(
|
||||||
|
length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
"gripper_width": Sequence(
|
||||||
|
length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
features = Features(features)
|
||||||
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
||||||
|
return hf_dataset
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
# Cleanup
|
||||||
|
if os.path.exists("tmp_umi_images"):
|
||||||
|
print("Removing temporary images folder")
|
||||||
|
shutil.rmtree("tmp_umi_images")
|
||||||
|
print("Cleanup done")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_episode_idxs(cls, episode_ends: np.ndarray) -> np.ndarray:
|
||||||
|
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
|
||||||
|
from numba import jit
|
||||||
|
|
||||||
|
@jit(nopython=True)
|
||||||
|
def _get_episode_idxs(episode_ends):
|
||||||
|
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||||
|
start_idx = 0
|
||||||
|
for episode_number, end_idx in enumerate(episode_ends):
|
||||||
|
result[start_idx:end_idx] = episode_number
|
||||||
|
start_idx = end_idx
|
||||||
|
return result
|
||||||
|
|
||||||
|
return _get_episode_idxs(episode_ends)
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_folder(folder_path: str):
|
||||||
|
"""
|
||||||
|
Clears all the content of the specified folder. Creates the folder if it does not exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_path (str): Path to the folder to clear.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import os
|
||||||
|
>>> os.makedirs('example_folder', exist_ok=True)
|
||||||
|
>>> with open('example_folder/temp_file.txt', 'w') as f:
|
||||||
|
... f.write('example')
|
||||||
|
>>> clear_folder('example_folder')
|
||||||
|
>>> os.listdir('example_folder')
|
||||||
|
[]
|
||||||
|
"""
|
||||||
|
if os.path.exists(folder_path):
|
||||||
|
for filename in os.listdir(folder_path):
|
||||||
|
file_path = os.path.join(folder_path, filename)
|
||||||
|
try:
|
||||||
|
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||||
|
os.unlink(file_path)
|
||||||
|
elif os.path.isdir(file_path):
|
||||||
|
shutil.rmtree(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||||
|
else:
|
||||||
|
os.makedirs(folder_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_image(img_array: np.array, i: int, folder_path: str):
    """
    Saves a single image to the specified folder.

    Args:
        img_array (ndarray): The numpy array of the image.
        i (int): Index of the image, used for naming.
        folder_path (str): Path to the folder where the image will be saved.
    """
    img = PILImage.fromarray(img_array)
    # uint8 frames are stored losslessly as PNG; other dtypes fall back to JPEG.
    # NOTE(review): PIL may reject non-uint8 arrays in fromarray — confirm upstream dtypes.
    img_format = "PNG" if img_array.dtype == np.uint8 else "JPEG"
    destination = os.path.join(folder_path, f"{i}.{img_format.lower()}")
    img.save(destination, quality=100)
|
||||||
|
|
||||||
|
|
||||||
|
def _save_images_concurrently(data: dict, image_key: str, folder_path: str, max_workers: int = 4):
    """
    Saves images from `data[image_key]` to the specified folder using multithreading.

    Args:
        data (dict): A mapping holding the image data array under `image_key`.
        image_key (str): Key of the image array inside `data`.
        folder_path (str): Path to the folder where images will be saved (cleared/created first).
        max_workers (int): The maximum number of threads to use for saving images.
    """
    from concurrent.futures import ThreadPoolExecutor

    # BUGFIX: the image count was previously read from the hard-coded
    # "data/camera0_rgb" key instead of the `image_key` parameter.
    num_images = len(data[image_key])
    _clear_folder(folder_path)  # Clear or create folder first

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_save_image, data[image_key][i], i, folder_path) for i in range(num_images)
        ]
        # Surface any exception raised inside a worker thread; previously these
        # were silently dropped, leaving missing images to misalign the dataset.
        for future in futures:
            future.result()
|
||||||
20
lerobot/common/datasets/push_dataset_to_hub/utils.py
Normal file
20
lerobot/common/datasets/push_dataset_to_hub/utils.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def concatenate_episodes(ep_dicts):
    """
    Merge per-episode dictionaries into one flat data dictionary.

    Tensor values are concatenated along dim 0; list-like values (e.g. lists
    of PIL images) are flattened into one list. A global "index" tensor
    (0..total_frames-1) is added, with total_frames taken from the mandatory
    "frame_index" key.

    Args:
        ep_dicts (list[dict]): One dict per episode; all dicts share the same keys.

    Returns:
        dict: The concatenated data dictionary with an extra "index" key.
    """
    data_dict = {}

    for key in ep_dicts[0]:
        if torch.is_tensor(ep_dicts[0][key][0]):
            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
        else:
            # Flatten list-valued entries across episodes with a comprehension
            # instead of the previous nested append loop.
            data_dict[key] = [x for ep_dict in ep_dicts for x in ep_dict[key]]

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict
|
||||||
145
lerobot/common/datasets/push_dataset_to_hub/xarm_processor.py
Normal file
145
lerobot/common/datasets/push_dataset_to_hub/xarm_processor.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import einops
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class XarmProcessor:
    """Process pickle files formatted like in: https://github.com/fyhMer/fowm"""

    def __init__(self, folder_path: str, fps: int | None = None):
        """
        Args:
            folder_path: Directory expected to contain exactly one `.pkl` file.
            fps: Frame rate of the recording; defaults to 15 when not given.
        """
        self.folder_path = Path(folder_path)
        # Top-level arrays every valid pickle must provide.
        self.keys = {"actions", "rewards", "dones", "masks"}
        # Dict-valued entries and the sub-arrays expected inside each.
        self.nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}}
        self._fps = 15 if fps is None else fps

    @property
    def fps(self) -> int:
        """Frame rate used to derive per-frame timestamps."""
        return self._fps

    def is_valid(self) -> bool:
        """Return True when the folder holds a single, well-formed pickle file."""
        pkl_files = list(self.folder_path.glob("*.pkl"))
        # Exactly one pickle file is expected.
        if len(pkl_files) != 1:
            return False

        try:
            with open(pkl_files[0], "rb") as file:
                raw = pickle.load(file)
        except Exception:
            # Unreadable or corrupted pickle.
            return False

        if not isinstance(raw, dict):
            return False

        if not all(key in raw for key in self.keys):
            return False

        # All top-level and nested arrays must share the same length.
        try:
            expected_len = len(raw["actions"])
            for key in self.keys:
                if key in raw and len(raw[key]) != expected_len:
                    return False

            for key, subkeys in self.nested_keys.items():
                nested = raw.get(key, {})
                for subkey in subkeys:
                    if subkey in nested and len(nested[subkey]) != expected_len:
                        return False
        except KeyError:  # a required key or subkey is missing
            return False

        return True  # all checks passed

    def preprocess(self):
        """Split the raw transitions into episodes and flatten them.

        Episodes are delimited by True entries in `dones`; frames after the
        final True (if any) are dropped, matching the boundary bookkeeping.

        Returns:
            A `(data_dict, episode_data_index)` pair where `data_dict` maps
            modality keys to concatenated values and `episode_data_index`
            holds per-episode "from"/"to" frame offsets.

        Raises:
            ValueError: When `is_valid()` rejects the folder contents.
        """
        if not self.is_valid():
            raise ValueError("The Xarm file is invalid or does not contain the required datasets.")

        pkl_files = list(self.folder_path.glob("*.pkl"))

        with open(pkl_files[0], "rb") as file:
            raw = pickle.load(file)
        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        start = 0
        end = 0
        ep_idx = 0
        total_frames = raw["actions"].shape[0]
        for frame_idx in tqdm.tqdm(range(total_frames)):
            end += 1

            # An episode boundary is marked by a truthy entry in `dones`.
            if not raw["dones"][frame_idx]:
                continue

            num_frames = end - start

            image = torch.tensor(raw["observations"]["rgb"][start:end])
            # Channel-first (b c h w) -> channel-last (b h w c) for PIL conversion.
            image = einops.rearrange(image, "b c h w -> b h w c")
            state = torch.tensor(raw["observations"]["state"][start:end])
            action = torch.tensor(raw["actions"][start:end])
            # TODO(rcadene): we have a missing last frame which is the observation when the env is done
            # it is critical to have this frame for tdmpc to predict a "done observation/state"
            # next_image = torch.tensor(raw["next_observations"]["rgb"][start:end])
            # next_state = torch.tensor(raw["next_observations"]["state"][start:end])
            next_reward = torch.tensor(raw["rewards"][start:end])
            next_done = torch.tensor(raw["dones"][start:end])

            ep_dicts.append(
                {
                    "observation.image": [PILImage.fromarray(frame.numpy()) for frame in image],
                    "observation.state": state,
                    "action": action,
                    "episode_index": torch.tensor([ep_idx] * num_frames, dtype=torch.int),
                    "frame_index": torch.arange(0, num_frames, 1),
                    "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                    # "next.observation.image": next_image,
                    # "next.observation.state": next_state,
                    "next.reward": next_reward,
                    "next.done": next_done,
                }
            )

            episode_data_index["from"].append(start)
            episode_data_index["to"].append(start + num_frames)

            start = end
            ep_idx += 1

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict):
        """Wrap the flat `data_dict` in a `datasets.Dataset` with typed features."""
        features = Features(
            {
                "observation.image": Image(),
                "observation.state": Sequence(
                    length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
                ),
                "action": Sequence(
                    length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
                ),
                "episode_index": Value(dtype="int64", id=None),
                "frame_index": Value(dtype="int64", id=None),
                "timestamp": Value(dtype="float32", id=None),
                "next.reward": Value(dtype="float32", id=None),
                "next.done": Value(dtype="bool", id=None),
                # 'next.success': Value(dtype='bool', id=None),
                "index": Value(dtype="int64", id=None),
            }
        )
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        # Rows come back as torch tensors on access.
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        # Nothing to release: preprocessing reads the pickle in place.
        pass
|
||||||
@@ -342,7 +342,6 @@ def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
|||||||
"max": max[key],
|
"max": max[key],
|
||||||
"min": min[key],
|
"min": min[key],
|
||||||
}
|
}
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ class ActionChunkingTransformerConfig:
|
|||||||
deviation and "min_max" which rescale in a [-1, 1] range.
|
deviation and "min_max" which rescale in a [-1, 1] range.
|
||||||
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
|
||||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||||
torchvision.
|
`None` means no pretrained weights.
|
||||||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||||||
convolution.
|
convolution.
|
||||||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
||||||
@@ -75,13 +75,13 @@ class ActionChunkingTransformerConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
input_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": "mean_std",
|
"observation.images.top": "mean_std",
|
||||||
"observation.state": "mean_std",
|
"observation.state": "mean_std",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
unnormalize_output_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"action": "mean_std",
|
"action": "mean_std",
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ class ActionChunkingTransformerConfig:
|
|||||||
# Architecture.
|
# Architecture.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
use_pretrained_backbone: bool = True
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
replace_final_stride_with_dilation: int = False
|
replace_final_stride_with_dilation: int = False
|
||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
pre_norm: bool = False
|
pre_norm: bool = False
|
||||||
|
|||||||
@@ -72,8 +72,11 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ActionChunkingTransformerConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(
|
||||||
|
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||||
|
)
|
||||||
|
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
@@ -101,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
# Backbone for image feature extraction.
|
# Backbone for image feature extraction.
|
||||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||||
pretrained=cfg.use_pretrained_backbone,
|
weights=cfg.pretrained_backbone_weights,
|
||||||
norm_layer=FrozenBatchNorm2d,
|
norm_layer=FrozenBatchNorm2d,
|
||||||
)
|
)
|
||||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
||||||
@@ -216,6 +219,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
loss_dict = self.forward(batch)
|
loss_dict = self.forward(batch)
|
||||||
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
# TODO(rcadene): self.unnormalize_outputs(out_dict)
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ class DiffusionConfig:
|
|||||||
within the image size. If None, no cropping is done.
|
within the image size. If None, no cropping is done.
|
||||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||||
mode).
|
mode).
|
||||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||||
torchvision.
|
`None` means no pretrained weights.
|
||||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||||
@@ -83,24 +83,20 @@ class DiffusionConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes: dict[str, str] = field(
|
input_normalization_modes: dict[str, str] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"observation.image": "mean_std",
|
"observation.image": "mean_std",
|
||||||
"observation.state": "min_max",
|
"observation.state": "min_max",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
unnormalize_output_modes: dict[str, str] = field(
|
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||||
default_factory=lambda: {
|
|
||||||
"action": "min_max",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: str = "resnet18"
|
vision_backbone: str = "resnet18"
|
||||||
crop_shape: tuple[int, int] | None = (84, 84)
|
crop_shape: tuple[int, int] | None = (84, 84)
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
use_pretrained_backbone: bool = False
|
pretrained_backbone_weights: str | None = None
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = True
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
# Unet.
|
# Unet.
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
name = "diffusion"
|
name = "diffusion"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = DiffusionConfig()
|
cfg = DiffusionConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||||
|
self.unnormalize_outputs = Unnormalize(
|
||||||
|
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||||
|
)
|
||||||
|
|
||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
self.diffusion.train()
|
self.diffusion.train()
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
loss = self.forward(batch)["loss"]
|
loss = self.forward(batch)["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@@ -374,13 +378,13 @@ class _RgbEncoder(nn.Module):
|
|||||||
|
|
||||||
# Set up backbone.
|
# Set up backbone.
|
||||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||||
pretrained=cfg.use_pretrained_backbone
|
weights=cfg.pretrained_backbone_weights
|
||||||
)
|
)
|
||||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||||
# TODO(alexander-soare): Use a safer alternative.
|
# TODO(alexander-soare): Use a safer alternative.
|
||||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||||
if cfg.use_group_norm:
|
if cfg.use_group_norm:
|
||||||
if cfg.use_pretrained_backbone:
|
if cfg.pretrained_backbone_weights:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,27 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
def create_stats_buffers(shapes, modes, stats=None):
|
def create_stats_buffers(
|
||||||
|
shapes: dict[str, list[int]],
|
||||||
|
modes: dict[str, str],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||||
"""
|
"""
|
||||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
|
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||||
|
statistics.
|
||||||
|
|
||||||
Parameters:
|
Args: (see Normalize and Unnormalize)
|
||||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
|
||||||
- "mean_std": substract the mean and divide by standard deviation.
|
|
||||||
- "min_max": map to [-1, 1] range.
|
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
|
||||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
|
||||||
they are already in the policy state_dict.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
|
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||||
`requires_grad=False`, suitable to not be updated during backpropagation.
|
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||||
"""
|
"""
|
||||||
stats_buffers = {}
|
stats_buffers = {}
|
||||||
|
|
||||||
@@ -75,24 +69,32 @@ def create_stats_buffers(shapes, modes, stats=None):
|
|||||||
|
|
||||||
|
|
||||||
class Normalize(nn.Module):
|
class Normalize(nn.Module):
|
||||||
"""
|
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||||
Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.
|
|
||||||
|
|
||||||
Parameters:
|
def __init__(
|
||||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`]).
|
self,
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
shapes: dict[str, list[int]],
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
modes: dict[str, str],
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
- "mean_std": substract the mean and divide by standard deviation.
|
):
|
||||||
- "min_max": map to [-1, 1] range.
|
"""
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
Args:
|
||||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||||
they are already in the policy state_dict.
|
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||||
"""
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||||
|
are their normalization modes among:
|
||||||
def __init__(self, shapes, modes, stats=None):
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
|
- "min_max": map to [-1, 1] range.
|
||||||
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||||
|
and values are dictionaries of statistic types and their values (e.g.
|
||||||
|
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||||
|
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||||
|
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||||
|
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||||
|
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapes = shapes
|
self.shapes = shapes
|
||||||
self.modes = modes
|
self.modes = modes
|
||||||
@@ -104,29 +106,33 @@ class Normalize(nn.Module):
|
|||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch):
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"]
|
mean = buffer["mean"]
|
||||||
std = buffer["std"]
|
std = buffer["std"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(mean).any(), (
|
||||||
mean
|
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
std
|
assert not torch.isinf(std).any(), (
|
||||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"]
|
min = buffer["min"]
|
||||||
max = buffer["max"]
|
max = buffer["max"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(min).any(), (
|
||||||
min
|
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
max
|
assert not torch.isinf(max).any(), (
|
||||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
# normalize to [0,1]
|
# normalize to [0,1]
|
||||||
batch[key] = (batch[key] - min) / (max - min)
|
batch[key] = (batch[key] - min) / (max - min)
|
||||||
# normalize to [-1, 1]
|
# normalize to [-1, 1]
|
||||||
@@ -138,23 +144,34 @@ class Normalize(nn.Module):
|
|||||||
|
|
||||||
class Unnormalize(nn.Module):
|
class Unnormalize(nn.Module):
|
||||||
"""
|
"""
|
||||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their original range used by the environment.
|
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||||
|
original range used by the environment.
|
||||||
Parameters:
|
|
||||||
shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. [10]).
|
|
||||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
|
||||||
and width, assuming a channel-first (c, h, w) format.
|
|
||||||
modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
|
|
||||||
- "mean_std": multiply by standard deviation and add mean
|
|
||||||
- "min_max": go from [-1, 1] range to original range.
|
|
||||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
|
|
||||||
(e.g. `{"max": torch.tensor(1)}, "min": torch.tensor(0)}`). If provided, as expected for training the model for the first time,
|
|
||||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
|
||||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
|
||||||
they are already in the policy state_dict.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, shapes, modes, stats=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
shapes: dict[str, list[int]],
|
||||||
|
modes: dict[str, str],
|
||||||
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||||
|
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||||
|
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||||
|
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||||
|
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||||
|
are their normalization modes among:
|
||||||
|
- "mean_std": subtract the mean and divide by standard deviation.
|
||||||
|
- "min_max": map to [-1, 1] range.
|
||||||
|
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||||
|
and values are dictionaries of statistic types and their values (e.g.
|
||||||
|
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||||
|
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||||
|
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||||
|
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||||
|
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapes = shapes
|
self.shapes = shapes
|
||||||
self.modes = modes
|
self.modes = modes
|
||||||
@@ -166,29 +183,33 @@ class Unnormalize(nn.Module):
|
|||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch):
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
if mode == "mean_std":
|
if mode == "mean_std":
|
||||||
mean = buffer["mean"]
|
mean = buffer["mean"]
|
||||||
std = buffer["std"]
|
std = buffer["std"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(mean).any(), (
|
||||||
mean
|
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
std
|
assert not torch.isinf(std).any(), (
|
||||||
).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = batch[key] * std + mean
|
batch[key] = batch[key] * std + mean
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"]
|
min = buffer["min"]
|
||||||
max = buffer["max"]
|
max = buffer["max"]
|
||||||
assert not torch.isinf(
|
assert not torch.isinf(min).any(), (
|
||||||
min
|
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`policy.load_state_dict`."
|
||||||
assert not torch.isinf(
|
)
|
||||||
max
|
assert not torch.isinf(max).any(), (
|
||||||
).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or called `policy.load_state_dict`."
|
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||||
|
"`policy.load_state_dict`."
|
||||||
|
)
|
||||||
batch[key] = (batch[key] + 1) / 2
|
batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
batch[key] = batch[key] * (max - min) + min
|
||||||
else:
|
else:
|
||||||
|
|||||||
12
lerobot/common/utils/io_utils.py
Normal file
12
lerobot/common/utils/io_utils.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
|
||||||
|
|
||||||
|
def write_video(video_path, stacked_frames, fps):
    """Encode `stacked_frames` into a video file at `video_path` using `fps`.

    Args:
        video_path: Destination path for the encoded video.
        stacked_frames: Sequence of frames accepted by `imageio.mimsave`.
        fps: Frame rate of the output video.
    """
    # imageio pulls in pkg_resources, which emits a DeprecationWarning on
    # import; silence just that message while encoding.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||||
@@ -92,7 +92,8 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
|||||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||||
# Hydra needs a path relative to this file.
|
# Hydra needs a path relative to this file.
|
||||||
hydra.initialize(
|
hydra.initialize(
|
||||||
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
|
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
|
||||||
|
version_base="1.2",
|
||||||
)
|
)
|
||||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -36,16 +36,16 @@ policy:
|
|||||||
action: ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
input_normalization_modes:
|
||||||
observation.images.top: mean_std
|
observation.images.top: mean_std
|
||||||
observation.state: mean_std
|
observation.state: mean_std
|
||||||
unnormalize_output_modes:
|
output_normalization_modes:
|
||||||
action: mean_std
|
action: mean_std
|
||||||
|
|
||||||
# Architecture.
|
# Architecture.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
vision_backbone: resnet18
|
vision_backbone: resnet18
|
||||||
use_pretrained_backbone: true
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
replace_final_stride_with_dilation: false
|
replace_final_stride_with_dilation: false
|
||||||
# Transformer layers.
|
# Transformer layers.
|
||||||
pre_norm: false
|
pre_norm: false
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ policy:
|
|||||||
action: ["${env.action_dim}"]
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
# Normalization / Unnormalization
|
# Normalization / Unnormalization
|
||||||
normalize_input_modes:
|
input_normalization_modes:
|
||||||
observation.image: mean_std
|
observation.image: mean_std
|
||||||
observation.state: min_max
|
observation.state: min_max
|
||||||
unnormalize_output_modes:
|
output_normalization_modes:
|
||||||
action: min_max
|
action: min_max
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
@@ -61,7 +61,7 @@ policy:
|
|||||||
vision_backbone: resnet18
|
vision_backbone: resnet18
|
||||||
crop_shape: [84, 84]
|
crop_shape: [84, 84]
|
||||||
crop_is_random: True
|
crop_is_random: True
|
||||||
use_pretrained_backbone: false
|
pretrained_backbone_weights: null
|
||||||
use_group_norm: True
|
use_group_norm: True
|
||||||
spatial_softmax_num_keypoints: 32
|
spatial_softmax_num_keypoints: 32
|
||||||
# Unet.
|
# Unet.
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import einops
|
import einops
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import imageio
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, Features, Image, Sequence, Value
|
from datasets import Dataset, Features, Image, Sequence, Value
|
||||||
@@ -51,13 +50,10 @@ from lerobot.common.envs.factory import make_env
|
|||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.utils.io_utils import write_video
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
|
||||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_policy(
|
def eval_policy(
|
||||||
env: gym.vector.VectorEnv,
|
env: gym.vector.VectorEnv,
|
||||||
policy: torch.nn.Module,
|
policy: torch.nn.Module,
|
||||||
|
|||||||
338
lerobot/scripts/push_dataset_to_hub.py
Normal file
338
lerobot/scripts/push_dataset_to_hub.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import (
|
||||||
|
AlohaProcessor,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor
|
||||||
|
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
||||||
|
|
||||||
|
|
||||||
|
def push_lerobot_dataset_to_hub(
|
||||||
|
hf_dataset: Dataset,
|
||||||
|
episode_data_index: dict[str, list[int]],
|
||||||
|
info: dict[str, Any],
|
||||||
|
stats: dict[str, dict[str, torch.Tensor]],
|
||||||
|
root: Path,
|
||||||
|
revision: str,
|
||||||
|
dataset_id: str,
|
||||||
|
community_id: str = "lerobot",
|
||||||
|
dry_run: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Pushes a dataset to the Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hf_dataset (Dataset): The dataset to be pushed.
|
||||||
|
episode_data_index (dict[str, list[int]]): The index of episode data.
|
||||||
|
info (dict[str, Any]): Information about the dataset, eg. fps.
|
||||||
|
stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset.
|
||||||
|
root (Path): The root directory of the dataset.
|
||||||
|
revision (str): The revision of the dataset.
|
||||||
|
dataset_id (str): The ID of the dataset.
|
||||||
|
community_id (str, optional): The ID of the community or the user where the
|
||||||
|
dataset will be stored. Defaults to "lerobot".
|
||||||
|
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
||||||
|
"""
|
||||||
|
if not dry_run:
|
||||||
|
# push to main to indicate latest version
|
||||||
|
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True)
|
||||||
|
|
||||||
|
# push to version branch
|
||||||
|
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision)
|
||||||
|
|
||||||
|
# create and store meta_data
|
||||||
|
meta_data_dir = root / community_id / dataset_id / "meta_data"
|
||||||
|
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# info
|
||||||
|
info_path = meta_data_dir / "info.json"
|
||||||
|
|
||||||
|
with open(str(info_path), "w") as f:
|
||||||
|
json.dump(info, f, indent=4)
|
||||||
|
# stats
|
||||||
|
stats_path = meta_data_dir / "stats.safetensors"
|
||||||
|
save_file(flatten_dict(stats), stats_path)
|
||||||
|
|
||||||
|
# episode_data_index
|
||||||
|
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||||
|
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||||
|
save_file(episode_data_index, ep_data_idx_path)
|
||||||
|
|
||||||
|
if not dry_run:
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=info_path,
|
||||||
|
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=info_path,
|
||||||
|
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
# stats
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=stats_path,
|
||||||
|
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=stats_path,
|
||||||
|
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=ep_data_idx_path,
|
||||||
|
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
)
|
||||||
|
api.upload_file(
|
||||||
|
path_or_fileobj=ep_data_idx_path,
|
||||||
|
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||||
|
repo_id=f"{community_id}/{dataset_id}",
|
||||||
|
repo_type="dataset",
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
# copy in tests folder, the first episode and the meta_data directory
|
||||||
|
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||||
|
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
||||||
|
f"tests/data/{community_id}/{dataset_id}/train"
|
||||||
|
)
|
||||||
|
if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists():
|
||||||
|
shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data")
|
||||||
|
shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data")
|
||||||
|
|
||||||
|
|
||||||
|
def push_dataset_to_hub(
|
||||||
|
dataset_id: str,
|
||||||
|
root: Path,
|
||||||
|
fps: int | None,
|
||||||
|
dataset_folder: Path | None = None,
|
||||||
|
dry_run: bool = False,
|
||||||
|
revision: str = "v1.1",
|
||||||
|
community_id: str = "lerobot",
|
||||||
|
no_preprocess: bool = False,
|
||||||
|
path_save_to_disk: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download a raw dataset if needed or access a local raw dataset, detect the raw format (e.g. aloha, pusht, umi) and process it accordingly in a common data format which is then pushed to the Hugging Face Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_id (str): The ID of the dataset.
|
||||||
|
root (Path): The root directory where the dataset will be downloaded.
|
||||||
|
fps (int | None): The desired frames per second for the dataset.
|
||||||
|
dataset_folder (Path | None, optional): The path to the dataset folder. If not provided, the dataset will be downloaded using the dataset ID. Defaults to None.
|
||||||
|
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
||||||
|
revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to preprocess the dataset. Defaults to "v1.1".
|
||||||
|
community_id (str, optional): The ID of the community. Defaults to "lerobot".
|
||||||
|
no_preprocess (bool, optional): If True, does not preprocesses the dataset. Defaults to False.
|
||||||
|
path_save_to_disk (str | None, optional): The path to save the dataset to disk. Works when `dry_run` is True, which allows to only save on disk without uploading. By default, the dataset is not saved on disk.
|
||||||
|
**kwargs: Additional keyword arguments for the preprocessor init method.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
if dataset_folder is None:
|
||||||
|
dataset_folder = download_raw(root=root, dataset_id=dataset_id)
|
||||||
|
|
||||||
|
if not no_preprocess:
|
||||||
|
processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
|
||||||
|
data_dict, episode_data_index = processor.preprocess()
|
||||||
|
hf_dataset = processor.to_hf_dataset(data_dict)
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"fps": processor.fps,
|
||||||
|
}
|
||||||
|
stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)
|
||||||
|
|
||||||
|
push_lerobot_dataset_to_hub(
|
||||||
|
hf_dataset=hf_dataset,
|
||||||
|
episode_data_index=episode_data_index,
|
||||||
|
info=info,
|
||||||
|
stats=stats,
|
||||||
|
root=root,
|
||||||
|
revision=revision,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
community_id=community_id,
|
||||||
|
dry_run=dry_run,
|
||||||
|
)
|
||||||
|
if path_save_to_disk:
|
||||||
|
hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))
|
||||||
|
|
||||||
|
processor.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetProcessor(Protocol):
|
||||||
|
"""A class for processing datasets.
|
||||||
|
|
||||||
|
This class provides methods for validating, preprocessing, and converting datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_path (str): The path to the folder containing the dataset.
|
||||||
|
fps (int | None): The frames per second of the dataset. If None, the default value is used.
|
||||||
|
*args: Additional positional arguments.
|
||||||
|
**kwargs: Additional keyword arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""Check if the dataset is valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the dataset is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def preprocess(self) -> tuple[dict, dict]:
|
||||||
|
"""Preprocess the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[dict, dict]: A tuple containing two dictionaries representing the preprocessed data.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def to_hf_dataset(self, data_dict: dict) -> Dataset:
|
||||||
|
"""Convert the preprocessed data to a Hugging Face dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_dict (dict): The preprocessed data.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset: The converted Hugging Face dataset.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self) -> int:
|
||||||
|
"""Get the frames per second of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The frames per second.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up any resources used by the dataset processor."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor:
|
||||||
|
if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||||
|
return processor
|
||||||
|
if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||||
|
return processor
|
||||||
|
if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||||
|
return processor
|
||||||
|
if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||||
|
return processor
|
||||||
|
# TODO: Propose a registration mechanism for new dataset types
|
||||||
|
raise ValueError(f"Could not guess dataset type for folder {dataset_folder}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
Main function to process command line arguments and push dataset to Hugging Face Hub.
|
||||||
|
|
||||||
|
Parses command line arguments to get dataset details and conditions under which the dataset
|
||||||
|
is processed and pushed. It manages dataset preparation and uploading based on the user-defined parameters.
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Push a dataset to the Hugging Face Hub with optional parameters for customization.",
|
||||||
|
epilog="""
|
||||||
|
Example usage:
|
||||||
|
python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess
|
||||||
|
|
||||||
|
This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to Hugging Face Hub,
|
||||||
|
with various parameters to control the processing and uploading behavior.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-folder",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="The filesystem path to the dataset folder. If not provided, the dataset must be identified and managed by other means.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-id",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Unique identifier for the dataset to be processed and uploaded.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--root", type=Path, required=True, help="Root directory where the dataset operations are managed."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dry-run",
|
||||||
|
action="store_true",
|
||||||
|
help="Simulate the push process without uploading any data, for testing purposes.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--community-id",
|
||||||
|
type=str,
|
||||||
|
default="lerobot",
|
||||||
|
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fps",
|
||||||
|
type=int,
|
||||||
|
help="Target frame rate for video or image sequence datasets. Optional and applicable only if the dataset includes temporal media.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--revision",
|
||||||
|
type=str,
|
||||||
|
default="v1.0",
|
||||||
|
help="Dataset version identifier to manage different iterations of the dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-preprocess",
|
||||||
|
action="store_true",
|
||||||
|
help="Does not preprocess the dataset, set this flag if you only want dowload the dataset raw.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--path-save-to-disk",
|
||||||
|
type=Path,
|
||||||
|
help="Optional path where the processed dataset can be saved locally.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
push_dataset_to_hub(
|
||||||
|
dataset_folder=args.dataset_folder,
|
||||||
|
dataset_id=args.dataset_id,
|
||||||
|
root=args.root,
|
||||||
|
fps=args.fps,
|
||||||
|
dry_run=args.dry_run,
|
||||||
|
community_id=args.community_id,
|
||||||
|
revision=args.revision,
|
||||||
|
no_preprocess=args.no_preprocess,
|
||||||
|
path_save_to_disk=args.path_save_to_disk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -22,7 +22,7 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
def train_cli(cfg: dict):
|
def train_cli(cfg: dict):
|
||||||
train(
|
train(
|
||||||
cfg,
|
cfg,
|
||||||
@@ -258,7 +258,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
policy,
|
policy,
|
||||||
video_dir=Path(out_dir) / "eval",
|
video_dir=Path(out_dir) / "eval",
|
||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
transform=offline_dataset.transform,
|
|
||||||
seed=cfg.seed,
|
seed=cfg.seed,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ MAX_NUM_STEPS = 1000
|
|||||||
FIRST_FRAME = 0
|
FIRST_FRAME = 0
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
def visualize_dataset_cli(cfg: dict):
|
def visualize_dataset_cli(cfg: dict):
|
||||||
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
||||||
|
|
||||||
|
|||||||
1397
poetry.lock
generated
1397
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,7 @@ authors = [
|
|||||||
"Alexander Soare <alexander.soare159@gmail.com>",
|
"Alexander Soare <alexander.soare159@gmail.com>",
|
||||||
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
|
"Quentin Gallouédec <quentin.gallouedec@ec-lyon.fr>",
|
||||||
"Simon Alibert <alibert.sim@gmail.com>",
|
"Simon Alibert <alibert.sim@gmail.com>",
|
||||||
|
"Adil Zouitine <adilzouitinegm@gmail.com>",
|
||||||
"Thomas Wolf <thomaswolfcontact@gmail.com>",
|
"Thomas Wolf <thomaswolfcontact@gmail.com>",
|
||||||
]
|
]
|
||||||
repository = "https://github.com/huggingface/lerobot"
|
repository = "https://github.com/huggingface/lerobot"
|
||||||
@@ -33,14 +34,14 @@ wandb = "^0.16.3"
|
|||||||
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
imageio = {extras = ["ffmpeg"], version = "^2.34.0"}
|
||||||
gdown = "^5.1.0"
|
gdown = "^5.1.0"
|
||||||
hydra-core = "^1.3.2"
|
hydra-core = "^1.3.2"
|
||||||
einops = "^0.7.0"
|
einops = "^0.8.0"
|
||||||
pymunk = "^6.6.0"
|
pymunk = "^6.6.0"
|
||||||
zarr = "^2.17.0"
|
zarr = "^2.17.0"
|
||||||
numba = "^0.59.0"
|
numba = "^0.59.0"
|
||||||
torch = "^2.2.1"
|
torch = "^2.2.1"
|
||||||
opencv-python = "^4.9.0.80"
|
opencv-python = "^4.9.0.80"
|
||||||
diffusers = "^0.26.3"
|
diffusers = "^0.27.2"
|
||||||
torchvision = "^0.17.1"
|
torchvision = "^0.18.0"
|
||||||
h5py = "^3.10.0"
|
h5py = "^3.10.0"
|
||||||
huggingface-hub = "^0.21.4"
|
huggingface-hub = "^0.21.4"
|
||||||
robomimic = "0.2.0"
|
robomimic = "0.2.0"
|
||||||
@@ -54,6 +55,8 @@ debugpy = {version = "^1.8.1", optional = true}
|
|||||||
pytest = {version = "^8.1.0", optional = true}
|
pytest = {version = "^8.1.0", optional = true}
|
||||||
pytest-cov = {version = "^5.0.0", optional = true}
|
pytest-cov = {version = "^5.0.0", optional = true}
|
||||||
datasets = "^2.19.0"
|
datasets = "^2.19.0"
|
||||||
|
imagecodecs = { version = "^2024.1.1", optional = true }
|
||||||
|
torchaudio = "^2.3.0"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
@@ -62,12 +65,13 @@ xarm = ["gym-xarm"]
|
|||||||
aloha = ["gym-aloha"]
|
aloha = ["gym-aloha"]
|
||||||
dev = ["pre-commit", "debugpy"]
|
dev = ["pre-commit", "debugpy"]
|
||||||
test = ["pytest", "pytest-cov"]
|
test = ["pytest", "pytest-cov"]
|
||||||
|
umi = ["imagecodecs"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 110
|
line-length = 110
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
exclude = [
|
exclude = [
|
||||||
|
"tests/data",
|
||||||
".bzr",
|
".bzr",
|
||||||
".direnv",
|
".direnv",
|
||||||
".eggs",
|
".eggs",
|
||||||
|
|||||||
5
tests/conftest.py
Normal file
5
tests/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from .utils import DEVICE
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_finish():
|
||||||
|
print(f"\nTesting with {DEVICE=}")
|
||||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"fps": 10
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,67 @@
|
|||||||
|
{
|
||||||
|
"citation": "",
|
||||||
|
"description": "",
|
||||||
|
"features": {
|
||||||
|
"observation.state": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 7,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"episode_index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"frame_index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"timestamp": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"episode_data_index_from": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"episode_data_index_to": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"end_pose": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 6,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"start_pos": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 6,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"gripper_width": {
|
||||||
|
"feature": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"length": 1,
|
||||||
|
"_type": "Sequence"
|
||||||
|
},
|
||||||
|
"index": {
|
||||||
|
"dtype": "int64",
|
||||||
|
"_type": "Value"
|
||||||
|
},
|
||||||
|
"observation.image": {
|
||||||
|
"_type": "Image"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"homepage": "",
|
||||||
|
"license": ""
|
||||||
|
}
|
||||||
13
tests/data/lerobot/umi_cup_in_the_wild/train/state.json
Normal file
13
tests/data/lerobot/umi_cup_in_the_wild/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"_data_files": [
|
||||||
|
{
|
||||||
|
"filename": "data-00000-of-00001.arrow"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"_fingerprint": "fd95ee932cb1fce2",
|
||||||
|
"_format_columns": null,
|
||||||
|
"_format_kwargs": {},
|
||||||
|
"_format_type": "torch",
|
||||||
|
"_output_all_columns": false,
|
||||||
|
"_split": null
|
||||||
|
}
|
||||||
@@ -16,6 +16,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
from lerobot import available_datasets
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
@@ -26,8 +27,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
shutil.rmtree(data_dir)
|
shutil.rmtree(data_dir)
|
||||||
|
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dataset = LeRobotDataset(repo_id=repo_id, root=data_dir)
|
||||||
dataset = LeRobotDataset(repo_id)
|
|
||||||
|
|
||||||
# save 2 first frames of first episode
|
# save 2 first frames of first episode
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
i = dataset.episode_data_index["from"][0].item()
|
||||||
@@ -64,4 +64,11 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors")
|
available_datasets = [
|
||||||
|
"lerobot/pusht",
|
||||||
|
"lerobot/xarm_push_medium",
|
||||||
|
"lerobot/aloha_sim_insertion_human",
|
||||||
|
"lerobot/umi_cup_in_the_wild",
|
||||||
|
]
|
||||||
|
for dataset in available_datasets:
|
||||||
|
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||||
|
|||||||
@@ -241,57 +241,65 @@ def test_flatten_unflatten_dict():
|
|||||||
def test_backward_compatibility():
|
def test_backward_compatibility():
|
||||||
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||||
|
|
||||||
repo_id = "lerobot/pusht"
|
all_repo_id = [
|
||||||
|
"lerobot/pusht",
|
||||||
|
# TODO (azouitine): Add artifacts for the following datasets
|
||||||
|
# "lerobot/aloha_sim_insertion_human",
|
||||||
|
# "lerobot/xarm_push_medium",
|
||||||
|
# "lerobot/umi_cup_in_the_wild",
|
||||||
|
]
|
||||||
|
for repo_id in all_repo_id:
|
||||||
|
dataset = LeRobotDataset(
|
||||||
|
repo_id,
|
||||||
|
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||||
|
)
|
||||||
|
|
||||||
dataset = LeRobotDataset(
|
data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||||
repo_id,
|
|
||||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
def load_and_compare(i):
|
||||||
|
new_frame = dataset[i] # noqa: B023
|
||||||
|
old_frame = load_file(data_dir / f"frame_{i}.safetensors") # noqa: B023
|
||||||
|
|
||||||
def load_and_compare(i):
|
new_keys = set(new_frame.keys())
|
||||||
new_frame = dataset[i]
|
old_keys = set(old_frame.keys())
|
||||||
old_frame = load_file(data_dir / f"frame_{i}.safetensors")
|
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||||
|
|
||||||
new_keys = set(new_frame.keys())
|
for key in new_frame:
|
||||||
old_keys = set(old_frame.keys())
|
assert (
|
||||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
new_frame[key] == old_frame[key]
|
||||||
|
).all(), f"{key=} for index={i} does not contain the same value"
|
||||||
|
|
||||||
for key in new_frame:
|
# test2 first frames of first episode
|
||||||
assert (
|
i = dataset.episode_data_index["from"][0].item()
|
||||||
new_frame[key] == old_frame[key]
|
load_and_compare(i)
|
||||||
).all(), f"{key=} for index={i} does not contain the same value"
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test2 first frames of first episode
|
# test 2 frames at the middle of first episode
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
i = int(
|
||||||
load_and_compare(i)
|
(dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2
|
||||||
load_and_compare(i + 1)
|
)
|
||||||
|
load_and_compare(i)
|
||||||
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test 2 frames at the middle of first episode
|
# test 2 last frames of first episode
|
||||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
i = dataset.episode_data_index["to"][0].item()
|
||||||
load_and_compare(i)
|
load_and_compare(i - 2)
|
||||||
load_and_compare(i + 1)
|
load_and_compare(i - 1)
|
||||||
|
|
||||||
# test 2 last frames of first episode
|
# TODO(rcadene): Enable testing on second and last episode
|
||||||
i = dataset.episode_data_index["to"][0].item()
|
# We currently cant because our test dataset only contains the first episode
|
||||||
load_and_compare(i - 2)
|
|
||||||
load_and_compare(i - 1)
|
|
||||||
|
|
||||||
# TODO(rcadene): Enable testing on second and last episode
|
# # test 2 first frames of second episode
|
||||||
# We currently cant because our test dataset only contains the first episode
|
# i = dataset.episode_data_index["from"][1].item()
|
||||||
|
# load_and_compare(i)
|
||||||
|
# load_and_compare(i+1)
|
||||||
|
|
||||||
# # test 2 first frames of second episode
|
# #test 2 last frames of second episode
|
||||||
# i = dataset.episode_data_index["from"][1].item()
|
# i = dataset.episode_data_index["to"][1].item()
|
||||||
# load_and_compare(i)
|
# load_and_compare(i-2)
|
||||||
# load_and_compare(i+1)
|
# load_and_compare(i-1)
|
||||||
|
|
||||||
# #test 2 last frames of second episode
|
# # test 2 last frames of last episode
|
||||||
# i = dataset.episode_data_index["to"][1].item()
|
# i = dataset.episode_data_index["to"][-1].item()
|
||||||
# load_and_compare(i-2)
|
# load_and_compare(i-2)
|
||||||
# load_and_compare(i-1)
|
# load_and_compare(i-1)
|
||||||
|
|
||||||
# # test 2 last frames of last episode
|
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
|
||||||
# load_and_compare(i-2)
|
|
||||||
# load_and_compare(i-1)
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# TODO(aliberts): Mute logging for these tests
|
# TODO(aliberts): Mute logging for these tests
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -11,28 +12,22 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
|
|||||||
|
|
||||||
|
|
||||||
def _run_script(path):
|
def _run_script(path):
|
||||||
subprocess.run(["python", path], check=True)
|
subprocess.run([sys.executable, path], check=True)
|
||||||
|
|
||||||
|
|
||||||
def test_example_1():
|
def test_example_1():
|
||||||
path = "examples/1_load_hugging_face_dataset.py"
|
path = "examples/1_load_lerobot_dataset.py"
|
||||||
_run_script(path)
|
_run_script(path)
|
||||||
assert Path("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4").exists()
|
assert Path("outputs/examples/1_load_lerobot_dataset/episode_5.mp4").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_example_2():
|
def test_examples_3_and_2():
|
||||||
path = "examples/2_load_lerobot_dataset.py"
|
|
||||||
_run_script(path)
|
|
||||||
assert Path("outputs/examples/2_load_lerobot_dataset/episode_5.mp4").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_examples_4_and_3():
|
|
||||||
"""
|
"""
|
||||||
Train a model with example 3, check the outputs.
|
Train a model with example 3, check the outputs.
|
||||||
Evaluate the trained model with example 2, check the outputs.
|
Evaluate the trained model with example 2, check the outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
path = "examples/4_train_policy.py"
|
path = "examples/3_train_policy.py"
|
||||||
|
|
||||||
with open(path) as file:
|
with open(path) as file:
|
||||||
file_contents = file.read()
|
file_contents = file.read()
|
||||||
@@ -54,7 +49,7 @@ def test_examples_4_and_3():
|
|||||||
for file_name in ["model.pt", "config.yaml"]:
|
for file_name in ["model.pt", "config.yaml"]:
|
||||||
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
|
||||||
|
|
||||||
path = "examples/3_evaluate_pretrained_policy.py"
|
path = "examples/2_evaluate_pretrained_policy.py"
|
||||||
|
|
||||||
with open(path) as file:
|
with open(path) as file:
|
||||||
file_contents = file.read()
|
file_contents = file.read()
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from lerobot.common.datasets.factory import make_dataset
|
|||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
|
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||||
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
@@ -113,6 +115,15 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||||||
new_policy.load_state_dict(policy.state_dict())
|
new_policy.load_state_dict(policy.state_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
|
||||||
|
def test_policy_defaults(policy_cls):
|
||||||
|
kwargs = {}
|
||||||
|
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
||||||
|
if policy_cls is DiffusionPolicy:
|
||||||
|
kwargs = {"lr_scheduler_num_training_steps": 1}
|
||||||
|
policy_cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"insert_temporal_dim",
|
"insert_temporal_dim",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user