Compare commits
67 Commits
my-fix-bas
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41132be602 | ||
|
|
8746276d41 | ||
|
|
f07887e8d1 | ||
|
|
8d360927af | ||
|
|
e07cb52baa | ||
|
|
e88af0e588 | ||
|
|
1ecaeabad0 | ||
|
|
0309a9fcbc | ||
|
|
588bf96559 | ||
|
|
e11d2e4197 | ||
|
|
253c649507 | ||
|
|
71715c3914 | ||
|
|
7c005c2aa1 | ||
|
|
d518b036d0 | ||
|
|
367d9bda7d | ||
|
|
601b5fdbfe | ||
|
|
20b74ae1eb | ||
|
|
b9b880bd8b | ||
|
|
5bd9cb1e72 | ||
|
|
2866d0770f | ||
|
|
4375a05a9f | ||
|
|
4acf99f622 | ||
|
|
5a6ea09248 | ||
|
|
9c0836c8d0 | ||
|
|
b0cca75e5e | ||
|
|
54b5c805bf | ||
|
|
eab5543750 | ||
|
|
6b6a990f4c | ||
|
|
c2a05a1fde | ||
|
|
6c4d122198 | ||
|
|
34c5d4ce07 | ||
|
|
c1b28f0b58 | ||
|
|
53ecec5fb2 | ||
|
|
65738f0a80 | ||
|
|
5d184a7811 | ||
|
|
1a5c1ef9c7 | ||
|
|
7866c1f7d1 | ||
|
|
3666ac9346 | ||
|
|
3daab2acbb | ||
|
|
c36d2253d0 | ||
|
|
e2e6f6e666 | ||
|
|
ff0029f84b | ||
|
|
39ad2d16d4 | ||
|
|
689c5efc72 | ||
|
|
eda0b996cd | ||
|
|
15e7a9d541 | ||
|
|
52fb4143b5 | ||
|
|
93c80b2cb1 | ||
|
|
5fbbaa1bc0 | ||
|
|
71d1f5e2c9 | ||
|
|
b520941cd9 | ||
|
|
64ed5258e6 | ||
|
|
392a8c32a7 | ||
|
|
969ef745a2 | ||
|
|
6fe42a72db | ||
|
|
2487228ea7 | ||
|
|
76436ca1de | ||
|
|
fbf2f2222a | ||
|
|
02bc4e03e0 | ||
|
|
624eaf1175 | ||
|
|
aed3eb4a94 | ||
|
|
8426c64f42 | ||
|
|
7c2bbee613 | ||
|
|
9d6886dd08 | ||
|
|
d67ca342e9 | ||
|
|
57c9c21c39 | ||
|
|
38c14571cc |
24
.github/workflows/build-docker-images.yml
vendored
@@ -40,24 +40,24 @@ jobs:
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push CPU
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-cpu/Dockerfile
|
||||
@@ -78,24 +78,24 @@ jobs:
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-gpu/Dockerfile
|
||||
@@ -110,23 +110,23 @@ jobs:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU dev
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/lerobot-gpu-dev/Dockerfile
|
||||
|
||||
23
.github/workflows/build_documentation.yml
vendored
@@ -1,23 +0,0 @@
|
||||
name: Build documentation
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
paths:
|
||||
- "docs/**"
|
||||
branches:
|
||||
- main
|
||||
- doc-builder*
|
||||
- v*-release
|
||||
|
||||
|
||||
jobs:
|
||||
build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
19
.github/workflows/build_pr_documentation.yml
vendored
@@ -1,19 +0,0 @@
|
||||
name: Build PR Documentation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: lerobot
|
||||
additional_args: --not_python_module
|
||||
4
.github/workflows/nightly-tests.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
container:
|
||||
image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
|
||||
image: huggingface/lerobot-cpu:latest
|
||||
options: --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu"
|
||||
container:
|
||||
image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
|
||||
image: huggingface/lerobot-gpu:latest
|
||||
options: --gpus all --shm-size "16gb"
|
||||
credentials:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
8
.github/workflows/quality.yml
vendored
@@ -33,12 +33,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
@@ -64,9 +64,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
|
||||
uses: crate-ci/typos@v1.29.10
|
||||
|
||||
8
.github/workflows/test-docker-build.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -64,17 +64,17 @@ jobs:
|
||||
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build Docker image
|
||||
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
file: ${{ matrix.docker-file }}
|
||||
context: .
|
||||
|
||||
12
.github/workflows/test.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
@@ -85,7 +85,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
@@ -117,7 +117,7 @@ jobs:
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
persist-credentials: false
|
||||
@@ -129,7 +129,7 @@ jobs:
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
|
||||
4
.github/workflows/trufflehog.yml
vendored
@@ -24,12 +24,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
extra_args: --only-verified
|
||||
|
||||
16
.github/workflows/upload_pr_documentation.yml
vendored
@@ -1,16 +0,0 @@
|
||||
name: Upload PR Documentation
|
||||
|
||||
on: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers
|
||||
workflow_run:
|
||||
workflows: [ "Build PR Documentation" ]
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
||||
@@ -37,18 +37,18 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/adhtruong/mirrors-typos
|
||||
rev: v1.32.0
|
||||
rev: v1.31.1
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.11
|
||||
rev: v0.11.5
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -57,12 +57,12 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.26.0
|
||||
rev: v8.24.3
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.8.0
|
||||
rev: v1.5.2
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
|
||||
40
README.md
@@ -23,35 +23,21 @@
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/12_use_so101.md">
|
||||
Build Your Own SO-101 Robot!</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Build Your Own SO-100 Robot!</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<div style="display: flex; gap: 1rem; justify-content: center; align-items: center;" >
|
||||
<img
|
||||
src="media/so101/so101.webp?raw=true"
|
||||
alt="SO-101 follower arm"
|
||||
title="SO-101 follower arm"
|
||||
style="width: 40%;"
|
||||
/>
|
||||
<img
|
||||
src="media/so101/so101-leader.webp?raw=true"
|
||||
alt="SO-101 leader arm"
|
||||
title="SO-101 leader arm"
|
||||
style="width: 40%;"
|
||||
/>
|
||||
</div>
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
|
||||
|
||||
<p><strong>Meet the updated SO100, the SO-101 – Just €114 per arm!</strong></p>
|
||||
<p><strong>Meet the SO-100 – Just $110 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
<p>Then sit back and watch your creation act autonomously! 🤯</p>
|
||||
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/12_use_so101.md">
|
||||
See the full SO-101 tutorial here.</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">
|
||||
Get the full SO-100 tutorial here.</a></p>
|
||||
|
||||
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
|
||||
<p>Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://github.com/huggingface/lerobot/blob/main/examples/11_use_lekiwi.md">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="media/lekiwi/kiwi.webp?raw=true" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
@@ -65,6 +51,7 @@
|
||||
|
||||
---
|
||||
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
@@ -211,6 +198,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali
|
||||
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
|
||||
|
||||
```
|
||||
TODO: IMPROVE
|
||||
dataset attributes:
|
||||
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
|
||||
│ ├ observation.images.cam_high (VideoFrame):
|
||||
@@ -221,9 +209,9 @@ dataset attributes:
|
||||
│ ├ episode_index (int64): index of the episode for this sample
|
||||
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp (float32): timestamp in the episode
|
||||
│ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode
|
||||
│ ├ next.done (bool): indicates the end of en episode ; True for the last frame in each episode
|
||||
│ └ index (int64): general index in the whole dataset
|
||||
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
|
||||
├ meta: contains 2 tensors with the start and end indices of each episode
|
||||
│ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0
|
||||
│ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,)
|
||||
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
|
||||
@@ -270,7 +258,7 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
|
||||
|
||||
@@ -321,7 +309,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a
|
||||
You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain:
|
||||
- `config.json`: A serialized version of the policy configuration (following the policy's dataclass config).
|
||||
- `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format.
|
||||
- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
- `train_config.json`: A consolidated configuration containing all parameter userd for training. The policy configuration should match `config.json` exactly. Thisis useful for anyone who wants to evaluate your policy or for reproducibility.
|
||||
|
||||
To upload these to the hub, run the following:
|
||||
```bash
|
||||
@@ -360,7 +348,7 @@ with profile(
|
||||
If you want, you can cite this work with:
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
|
||||
@@ -108,7 +108,8 @@ def save_decoded_frames(
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
@@ -265,7 +266,8 @@ def benchmark_encoding_decoding(
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
@@ -416,7 +418,7 @@ if __name__ == "__main__":
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["libx264", "hevc", "libsvtav1"],
|
||||
default=["libx264", "libx265", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -446,7 +448,7 @@ if __name__ == "__main__":
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
||||
# "For libx264 and libx265, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
|
||||
137
docs/README.md
@@ -1,137 +0,0 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Generating the documentation
|
||||
|
||||
To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
|
||||
you can install them with the following command, at the root of the code repository:
|
||||
|
||||
```bash
|
||||
pip install -e ".[docs]"
|
||||
```
|
||||
|
||||
You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download)
|
||||
|
||||
---
|
||||
**NOTE**
|
||||
|
||||
You only need to generate the documentation to inspect it locally (if you're planning changes and want to
|
||||
check how they look before committing for instance). You don't have to `git commit` the built documentation.
|
||||
|
||||
---
|
||||
|
||||
## Building the documentation
|
||||
|
||||
Once you have setup the `doc-builder` and additional packages, you can generate the documentation by
|
||||
typing the following command:
|
||||
|
||||
```bash
|
||||
doc-builder build lerobot docs/source/ --build_dir ~/tmp/test-build
|
||||
```
|
||||
|
||||
You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate
|
||||
the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite
|
||||
Markdown editor.
|
||||
|
||||
## Previewing the documentation
|
||||
|
||||
To preview the docs, first install the `watchdog` module with:
|
||||
|
||||
```bash
|
||||
pip install watchdog
|
||||
```
|
||||
|
||||
Then run the following command:
|
||||
|
||||
```bash
|
||||
doc-builder preview lerobot docs/source/
|
||||
```
|
||||
|
||||
The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
|
||||
|
||||
---
|
||||
**NOTE**
|
||||
|
||||
The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
|
||||
|
||||
---
|
||||
|
||||
## Adding a new element to the navigation bar
|
||||
|
||||
Accepted files are Markdown (.md).
|
||||
|
||||
Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
|
||||
the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/lerobot/blob/main/docs/source/_toctree.yml) file.
|
||||
|
||||
## Renaming section headers and moving sections
|
||||
|
||||
It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
|
||||
|
||||
Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
|
||||
|
||||
So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
|
||||
|
||||
```
|
||||
Sections that were moved:
|
||||
|
||||
[ <a href="#section-b">Section A</a><a id="section-a"></a> ]
|
||||
```
|
||||
and of course, if you moved it to another file, then:
|
||||
|
||||
```
|
||||
Sections that were moved:
|
||||
|
||||
[ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
|
||||
```
|
||||
|
||||
Use the relative style to link to the new file so that the versioned docs continue to work.
|
||||
|
||||
For an example of a rich moved sections set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).
|
||||
|
||||
### Adding a new tutorial
|
||||
|
||||
Adding a new tutorial or section is done in two steps:
|
||||
|
||||
- Add a new file under `./source`. This file can either be ReStructuredText (.rst) or Markdown (.md).
|
||||
- Link that file in `./source/_toctree.yml` on the correct toc-tree.
|
||||
|
||||
Make sure to put your new file under the proper section. If you have a doubt, feel free to ask in a Github Issue or PR.
|
||||
|
||||
### Writing source documentation
|
||||
|
||||
Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
|
||||
and objects like True, None or any strings should usually be put in `code`.
|
||||
|
||||
#### Writing a multi-line code block
|
||||
|
||||
Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
|
||||
|
||||
|
||||
````
|
||||
```
|
||||
# first line of code
|
||||
# second line
|
||||
# etc
|
||||
```
|
||||
````
|
||||
|
||||
#### Adding an image
|
||||
|
||||
Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
|
||||
the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
|
||||
them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
|
||||
If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
|
||||
to this dataset.
|
||||
@@ -1,12 +0,0 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: LeRobot
|
||||
- local: installation
|
||||
title: Installation
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: assemble_so101
|
||||
title: Assemble SO-101
|
||||
- local: getting_started_real_world_robot
|
||||
title: Getting Started with Real-World Robots
|
||||
title: "Tutorials"
|
||||
@@ -1,348 +0,0 @@
|
||||
# Assemble SO-101
|
||||
|
||||
In the steps below we explain how to assemble our flagship robot, the SO-101.
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts,
|
||||
and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
To install LeRobot follow our [Installation Guide](./installation)
|
||||
|
||||
## Configure motors
|
||||
|
||||
To configure the motors designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor if it's for the follower `F` or for the leader `L` and it's ID from 1 to 6.
|
||||
|
||||
You now should plug the 5V or 12V power supply to the motor bus. 5V for the STS3215 7.4V motors and 12V for the STS3215 12V motors. Note that the leader arm always uses the 7.4V motors, so watch out that you plug in the right power supply if you have 12V and 7.4V motors, otherwise you might burn your motors! Now, connect the motor bus to your computer via USB. Note that the USB doesn't provide any power, and both the power supply and USB have to be plugged in.
|
||||
|
||||
### Find the USB ports associated to each arm
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
##### Example outputs of script
|
||||
|
||||
<hfoptions id="example">
|
||||
<hfoption id="Mac">
|
||||
|
||||
Example output leader arm's port: `/dev/tty.usbmodem575E0031751`
|
||||
|
||||
```bash
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output follower arm port: `/dev/tty.usbmodem575E0032081`
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
Example output leader arm port: `/dev/ttyACM0`
|
||||
|
||||
```bash
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/ttyACM0', '/dev/ttyACM1']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/ttyACM0
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output follower arm port: `/dev/ttyACM1`
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/ttyACM0', '/dev/ttyACM1']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/ttyACM1
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### Update config file
|
||||
|
||||
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
|
||||
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
|
||||
```diff
|
||||
@RobotConfig.register_subclass("so101")
|
||||
@dataclass
|
||||
class So101RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so101"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Here is a video of the process:
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-find-motorbus.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
## Step-by-Step Assembly Instructions
|
||||
|
||||
The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader however uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in table below.
|
||||
|
||||
| Leader-Arm Axis | Motor | Gear Ratio |
|
||||
|-----------------|:-------:|:----------:|
|
||||
| Base / Shoulder Yaw | 1 | 1 / 191 |
|
||||
| Shoulder Pitch | 2 | 1 / 345 |
|
||||
| Elbow | 3 | 1 / 191 |
|
||||
| Wrist Roll | 4 | 1 / 147 |
|
||||
| Wrist Pitch | 5 | 1 / 147 |
|
||||
| Gripper | 6 | 1 / 147 |
|
||||
|
||||
### Set motor IDs
|
||||
|
||||
Plug your motor in one of the two ports of the motor bus and run this script to set its ID to 1. Replace the text after --port to the corresponding control board port.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo this process for all your motors until ID 6. Do the same for the 6 motors of the leader arm, but make sure to change the power supply if you use motors with different voltage and make sure you give the right ID to the right motor according to the table above.
|
||||
|
||||
Here is a video of the process:
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-configure-motor.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Clean Parts
|
||||
Remove all support material from the 3D-printed parts, the easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
##### Wiring
|
||||
|
||||
- Attach the motor controller on the back.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Wiring_v2.mp4" type="video/mp4" />
|
||||
</video>
|
||||
</div>
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-101 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
The calibration process is very important because it allows a neural network trained on one SO-101 robot to work on another.
|
||||
|
||||
#### Manual calibration of follower arm
|
||||
|
||||
You will need to move the follower arm to these positions sequentially, note that the rotated position is on the right side of the robot and you have to open the gripper fully.
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/follower_middle.webp?raw=true" alt="SO-101 leader arm middle position" title="SO-101 leader arm middle position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/follower_zero.webp?raw=true" alt="SO-101 leader arm zero position" title="SO-101 leader arm zero position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/follower_rotated.webp?raw=true" alt="SO-101 leader arm rotated position" title="SO-101 leader arm rotated position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/follower_rest.webp?raw=true" alt="SO-101 leader arm rest position" title="SO-101 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
#### Manual calibration of leader arm
|
||||
You will also need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/leader_middle.webp?raw=true" alt="SO-101 leader arm middle position" title="SO-101 leader arm middle position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/leader_zero.webp?raw=true" alt="SO-101 leader arm zero position" title="SO-101 leader arm zero position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/leader_rotated.webp?raw=true" alt="SO-101 leader arm rotated position" title="SO-101 leader arm rotated position" style="width:100%;"> | <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/leader_rest.webp?raw=true" alt="SO-101 leader arm rest position" title="SO-101 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
|
||||
@@ -1,370 +0,0 @@
|
||||
# Getting Started with Real-World Robots
|
||||
|
||||
This tutorial will explain you how to train a neural network to autonomously control a real robot.
|
||||
|
||||
**You'll learn:**
|
||||
1. How to record and visualize your dataset.
|
||||
2. How to train a policy using your data and prepare it for evaluation.
|
||||
3. How to evaluate your policy and visualize the results.
|
||||
|
||||
By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934).
|
||||
|
||||
This tutorial is specifically made for the affordable [SO-101](https://github.com/TheRobotStudio/SO-ARM100) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](https://aloha-2.github.io) by changing some configurations. The SO-101 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot.
|
||||
|
||||
During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously.
|
||||
|
||||
If you encounter any issues at any step of the tutorial, feel free to seek help on [Discord](https://discord.com/invite/s3KuuzsPFb) or don't hesitate to iterate with us on the tutorial by creating issues or pull requests.
|
||||
|
||||
## Setup and Calibrate
|
||||
|
||||
If you haven't yet setup and calibrate the SO-101 follow these steps:
|
||||
1. [Find ports and update config file](./assemble_so101#find-the-usb-ports-associated-to-each-arm)
|
||||
2. [Calibrate](./assemble_so101#calibrate)
|
||||
|
||||
## Teleoperate
|
||||
|
||||
Run this simple script to teleoperate your robot (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
The teleoperate command will automatically:
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
|
||||
## Setup Cameras
|
||||
|
||||
To connect a camera you have three options:
|
||||
1. OpenCVCamera which allows us to use any camera: usb, realsense, laptop webcam
|
||||
2. iPhone camera with MacOS
|
||||
3. Phone camera on Linux
|
||||
|
||||
### Use OpenCVCamera
|
||||
|
||||
The [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py) class allows you to efficiently record frames from most cameras using the [`opencv2`](https://docs.opencv.org) library. For more details on compatibility, see [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To instantiate an [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py), you need a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera like a webcam of a laptop, the camera index is usually `0` but it might differ, and the camera index might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.
|
||||
|
||||
To find the camera indices, run the following utility script, which will save a few frames from each detected camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py \
|
||||
--images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
```
|
||||
Mac or Windows detected. Finding available camera indices through scanning all indices from 0 to 60
|
||||
[...]
|
||||
Camera found at index 0
|
||||
Camera found at index 1
|
||||
[...]
|
||||
Connecting cameras
|
||||
OpenCVCamera(0, fps=30.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
OpenCVCamera(1, fps=24.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
Saving images to outputs/images_from_opencv_cameras
|
||||
Frame: 0000 Latency (ms): 39.52
|
||||
[...]
|
||||
Frame: 0046 Latency (ms): 40.07
|
||||
Images have been saved to outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
Check the saved images in `outputs/images_from_opencv_cameras` to identify which camera index corresponds to which physical camera (e.g. `0` for `camera_00` or `1` for `camera_01`):
|
||||
```
|
||||
camera_00_frame_000000.png
|
||||
[...]
|
||||
camera_00_frame_000047.png
|
||||
camera_01_frame_000000.png
|
||||
[...]
|
||||
camera_01_frame_000047.png
|
||||
```
|
||||
|
||||
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
|
||||
|
||||
Now that you have the camera indexes, you should specify the camera's in the config.
|
||||
|
||||
### Use your phone
|
||||
<hfoptions id="use phone">
|
||||
<hfoption id="Mac">
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later.
|
||||
- Sign in both devices with the same Apple ID.
|
||||
- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection.
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Linux">
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
|
||||
1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
```python
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android.
|
||||
3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
5. *Start OBS Studio*. Launch with:
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices:
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Teleoperate with cameras
|
||||
|
||||
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=teleoperate
|
||||
--control.display_data=true
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-101.
|
||||
|
||||
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
|
||||
|
||||
Add your token to the cli by running this command:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Now you can record a dataset, to record 2 episodes and upload your dataset to the hub execute this command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/so101_test \
|
||||
--control.tags='["so101","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
```
|
||||
|
||||
| Field | Meaning |
|
||||
|:---|:---|
|
||||
| `2024-08-10 15:02:58` | Timestamp when `print` was called. |
|
||||
| `ol_robot.py:219` | Source file and line number of the `print` call (`lerobot/scripts/control_robot.py` at line `219`). |
|
||||
| `dt: 33.34 (30.0 Hz)` | Delta time (ms) between teleop steps (target: 30.0 Hz, `--fps 30`). Yellow if step is too slow. |
|
||||
| `dtRlead: 5.06 (197.5 Hz)` | Delta time (ms) for reading present position from the **leader arm**. |
|
||||
| `dtWfoll: 0.25 (3963.7 Hz)` | Delta time (ms) for writing goal position to the **follower arm** (asynchronous). |
|
||||
| `dtRfoll: 6.22 (160.7 Hz)` | Delta time (ms) for reading present position from the **follower arm**. |
|
||||
| `dtRlaptop: 32.57 (30.7 Hz)` | Delta time (ms) for capturing an image from the **laptop camera** (async thread). |
|
||||
| `dtRphone: 33.84 (29.5 Hz)` | Delta time (ms) for capturing an image from the **phone camera** (async thread). |
|
||||
|
||||
|
||||
#### Dataset upload
|
||||
Locally your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/so101_test`). At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/so101_test
|
||||
```
|
||||
Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
|
||||
|
||||
#### Record function
|
||||
|
||||
The `record` function provides a suite of tools for capturing and managing data during robot operation:
|
||||
|
||||
##### 1. Frame Capture and Video Encoding
|
||||
- Frames from cameras are saved to disk during recording.
|
||||
- At the end of each episode, frames are encoded into video files.
|
||||
|
||||
##### 2. Data Storage
|
||||
- Data is stored using the `LeRobotDataset` format.
|
||||
- By default, the dataset is pushed to your Hugging Face page.
|
||||
- To disable uploading, use `--control.push_to_hub=false`.
|
||||
|
||||
##### 3. Checkpointing and Resuming
|
||||
- Checkpoints are automatically created during recording.
|
||||
- If an issue occurs, you can resume by re-running the same command with `--control.resume=true`.
|
||||
- To start recording from scratch, **manually delete** the dataset directory.
|
||||
|
||||
##### 4. Recording Parameters
|
||||
Set the flow of data recording using command-line arguments:
|
||||
- `--control.warmup_time_s=10`
|
||||
Number of seconds before starting data collection (default: **10 seconds**).
|
||||
Allows devices to warm up and synchronize.
|
||||
- `--control.episode_time_s=60`
|
||||
Duration of each data recording episode (default: **60 seconds**).
|
||||
- `--control.reset_time_s=60`
|
||||
Duration for resetting the environment after each episode (default: **60 seconds**).
|
||||
- `--control.num_episodes=50`
|
||||
Total number of episodes to record (default: **50**).
|
||||
|
||||
##### 5. Keyboard Controls During Recording
|
||||
Control the data recording flow using keyboard shortcuts:
|
||||
- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next.
|
||||
- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it.
|
||||
- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset.
|
||||
|
||||
#### Tips for gathering data
|
||||
|
||||
Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
|
||||
|
||||
In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
|
||||
|
||||
Avoid adding too much variation too quickly, as it may hinder your results.
|
||||
|
||||
|
||||
#### Troubleshooting:
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/so101_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can visualize it locally with (via a window in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so101_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
<div style="text-align:center;">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%"></img>
|
||||
</div>
|
||||
|
||||
## Replay an episode
|
||||
|
||||
A useful feature is the `replay` function, which allows to replay on your robot any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model.
|
||||
|
||||
You can replay the first episode on your robot with:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/so101_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain the command:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so101_test/checkpoints`.
|
||||
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_so101_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
||||
@@ -1,19 +0,0 @@
|
||||
<div class="flex justify-center">
|
||||
<a target="_blank" href="https://huggingface.co/lerobot">
|
||||
<img alt="HuggingFace Expert Acceleration Program" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-logo-thumbnail.png" style="width: 100%"></img>
|
||||
</a>
|
||||
</div>
|
||||
|
||||
# LeRobot
|
||||
|
||||
**State-of-the-art machine learning for real-world robotics**
|
||||
|
||||
🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models.
|
||||
|
||||
🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning.
|
||||
|
||||
🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started.
|
||||
|
||||
🤗 LeRobot hosts pretrained models and datasets on the LeRobot HuggingFace page.
|
||||
|
||||
Join the LeRobot community on [Discord](https://discord.gg/s3KuuzsPFb)
|
||||
@@ -1,84 +0,0 @@
|
||||
# Installation
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
|
||||
Now restart the shell by running:
|
||||
<hfoptions id="shell_restart">
|
||||
<hfoption id="Windows">
|
||||
|
||||
```bash
|
||||
source ~/.bashrc
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Mac">
|
||||
|
||||
```bash
|
||||
source ~/.bash_profile
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="zshell">
|
||||
|
||||
```bash
|
||||
source ~/.zshrc
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
cd lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
To install these for linux run:
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
|
||||
```
|
||||
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
## Sim
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
- [xarm](https://github.com/huggingface/gym-xarm)
|
||||
- [pusht](https://github.com/huggingface/gym-pusht)
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
## W&B
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
```bash
|
||||
wandb login
|
||||
```
|
||||
@@ -128,7 +128,7 @@ sudo chmod 666 /dev/ttyACM1
|
||||
#### d. Update config file
|
||||
|
||||
IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like:
|
||||
```diff
|
||||
```python
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
@@ -141,8 +141,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -159,8 +158,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -447,16 +445,18 @@ For the leader configuration, perform **Steps 1–23**. Make sure that you remov
|
||||
|
||||
## E. Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
The calibration process is very important because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### Manual calibration of follower arm
|
||||
#### a. Manual calibration of follower arm
|
||||
|
||||
You will need to move the follower arm to these positions sequentially, note that the rotated position is on the right side of the robot and you have to open the gripper fully.
|
||||
> [!IMPORTANT]
|
||||
> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so101/follower_middle.webp?raw=true" alt="SO-101 leader arm middle position" title="SO-101 leader arm middle position" style="width:100%;"> | <img src="../media/so101/follower_zero.webp?raw=true" alt="SO-101 leader arm zero position" title="SO-101 leader arm zero position" style="width:100%;"> | <img src="../media/so101/follower_rotated.webp?raw=true" alt="SO-101 leader arm rotated position" title="SO-101 leader arm rotated position" style="width:100%;"> | <img src="../media/so101/follower_rest.webp?raw=true" alt="SO-101 leader arm rest position" title="SO-101 leader arm rest position" style="width:100%;"> |
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
@@ -467,12 +467,12 @@ python lerobot/scripts/control_robot.py \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
#### Manual calibration of leader arm
|
||||
You will also need to move the leader arm to these positions sequentially:
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so101/leader_middle.webp?raw=true" alt="SO-100 leader arm middle position" title="SO-100 leader arm middle position" style="width:100%;"> | <img src="../media/so101/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so101/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so101/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
@@ -580,7 +580,7 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ First we will assemble the two SO100 arms. One to attach to the mobile base and
|
||||
|
||||
## SO100 Arms
|
||||
### Configure motors
|
||||
The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These need to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9.
|
||||
The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These needs to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9.
|
||||
|
||||
<img src="../media/lekiwi/motor_ids.webp?raw=true" alt="Motor ID's for mobile robot" title="Motor ID's for mobile robot" width="60%">
|
||||
|
||||
@@ -567,7 +567,7 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ cd ~/lerobot && pip install -e ".[feetech]"
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Follow step 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below.
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
@@ -164,7 +164,7 @@ Try to avoid rotating the motor while doing so to keep position 2048 set during
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
@@ -301,7 +301,7 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
|
||||
@@ -1,711 +0,0 @@
|
||||
# Assemble and use SO-101
|
||||
|
||||
In the steps below we explain how to assemble and use our flagship robot, the SO-101 with LeRobot 🤗.
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts,
|
||||
and advice if it's your first time printing or if you don't own a 3D printer.
|
||||
|
||||
Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
> [!TIP]
|
||||
> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line)
|
||||
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
```
|
||||
Now restart the shell by running:
|
||||
|
||||
##### Windows:
|
||||
```bash
|
||||
`source ~/.bashrc`
|
||||
```
|
||||
|
||||
##### Mac:
|
||||
```bash
|
||||
`source ~/.bash_profile`
|
||||
```
|
||||
|
||||
##### zshell:
|
||||
```bash
|
||||
`source ~/.zshrc`
|
||||
```
|
||||
|
||||
Then activate your conda environment, you have to do this each time you open a shell to use lerobot:
|
||||
```bash
|
||||
conda activate lerobot
|
||||
```
|
||||
|
||||
When using `miniconda`, install `ffmpeg` in your environment:
|
||||
```bash
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
cd lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
|
||||
## Configure motors
|
||||
|
||||
To configure the motors designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor if it's for the follower `F` or for the leader `L` and it's ID from 1 to 6.
|
||||
|
||||
You now should plug the 5V or 12V power supply to the motor bus. 5V for the STS3215 7.4V motors and 12V for the STS3215 12V motors. Note that the leader arm always uses the 7.4V motors, so watch out that you plug in the right power supply if you have 12V and 7.4V motors, otherwise you might burn your motors! Now, connect the motor bus to your computer via USB. Note that the USB doesn't provide any power, and both the power supply and USB have to be plugged in.
|
||||
|
||||
### Find the USB ports associated to each arm
|
||||
|
||||
To find the port for each bus servo adapter, run this script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
#### Example outputs of script
|
||||
|
||||
##### Mac:
|
||||
Example output leader arm's port: `/dev/tty.usbmodem575E0031751`
|
||||
|
||||
```bash
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output follower arm port: `/dev/tty.usbmodem575E0032081`
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
##### Linux:
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
Example output leader arm port: `/dev/ttyACM0`
|
||||
|
||||
```bash
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/ttyACM0', '/dev/ttyACM1']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/ttyACM0
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output follower arm port: `/dev/ttyACM1`
|
||||
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/ttyACM0', '/dev/ttyACM1']
|
||||
Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this MotorsBus is /dev/ttyACM1
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### Update config file
|
||||
|
||||
Now that you have your ports, update the **port** default values of [`SO101RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py).
|
||||
You will find a class called `so101` where you can update the `port` values with your actual motor ports:
|
||||
```diff
|
||||
@RobotConfig.register_subclass("so101")
|
||||
@dataclass
|
||||
class So101RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so101"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem58760431091",
|
||||
+ port="{ADD YOUR LEADER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
- port="/dev/tty.usbmodem585A0076891",
|
||||
+ port="{ADD YOUR FOLLOWER PORT}",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
Here is a video of the process:
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/fc45d756-31bb-4a61-b973-a87d633d08a7" type="video/mp4"></video>
|
||||
|
||||
### Set motor IDs
|
||||
|
||||
Now we need to set the motor ID for each motor. Plug your motor in only one of the two ports of the motor bus and run this script to set its ID to 1. Replace the text after --port to the corresponding control board port.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo this process for all your motors until ID 6. Do the same for the 6 motors of the leader arm, but make sure to change the power supply if you use motors with different voltage.
|
||||
|
||||
Here is a video of the process:
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/b31c115f-e706-4dcd-b7f1-4535da62416d" type="video/mp4"></video>
|
||||
|
||||
## Step-by-Step Assembly Instructions
|
||||
|
||||
The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader however uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in table below.
|
||||
|
||||
| Leader-Arm Axis | Motor | Gear Ratio |
|
||||
|-----------------|:-------:|:----------:|
|
||||
| Base / Shoulder Yaw | 1 | 1 / 191 |
|
||||
| Shoulder Pitch | 2 | 1 / 345 |
|
||||
| Elbow | 3 | 1 / 191 |
|
||||
| Wrist Roll | 4 | 1 / 147 |
|
||||
| Wrist Pitch | 5 | 1 / 147 |
|
||||
| Gripper | 6 | 1 / 147 |
|
||||
|
||||
|
||||
### Clean Parts
|
||||
Remove all support material from the 3D-printed parts.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/b0ee9dee-a2d0-445b-8489-02ebecb3d639" type="video/mp4"></video>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/32453dc2-5006-4140-9f56-f0d78eae5155" type="video/mp4"></video>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/7384b9a7-a946-440c-b292-91391bcc4d6b" type="video/mp4"></video>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/dca78ad0-7c36-4bdf-8162-c9ac42a1506f" type="video/mp4"></video>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/55f5d245-976d-49ff-8b4a-59843c441b12" type="video/mp4"></video>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
#### Follower:
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/6f766aa9-cfae-4388-89e7-0247f198c086" type="video/mp4"></video>
|
||||
|
||||
#### Leader:
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/1308c93d-2ef1-4560-8e93-a3812568a202" type="video/mp4"></video>
|
||||
|
||||
##### Wiring
|
||||
|
||||
- Attach the motor controller on the back.
|
||||
- Then insert all wires, use the wire guides everywhere to make sure the wires don't unplug themselves and stay in place.
|
||||
|
||||
<video controls width="640" src="https://github.com/user-attachments/assets/4c2cacfd-9276-4ee4-8bf2-ba2492667b78" type="video/mp4"></video>
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-101 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
The calibration process is very important because it allows a neural network trained on one SO-101 robot to work on another.
|
||||
|
||||
#### Manual calibration of follower arm
|
||||
|
||||
You will need to move the follower arm to these positions sequentially, note that the rotated position is on the right side of the robot and you have to open the gripper fully.
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so101/follower_middle.webp?raw=true" alt="SO-101 leader arm middle position" title="SO-101 leader arm middle position" style="width:100%;"> | <img src="../media/so101/follower_zero.webp?raw=true" alt="SO-101 leader arm zero position" title="SO-101 leader arm zero position" style="width:100%;"> | <img src="../media/so101/follower_rotated.webp?raw=true" alt="SO-101 leader arm rotated position" title="SO-101 leader arm rotated position" style="width:100%;"> | <img src="../media/so101/follower_rest.webp?raw=true" alt="SO-101 leader arm rest position" title="SO-101 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_follower"]'
|
||||
```
|
||||
|
||||
#### Manual calibration of leader arm
|
||||
You will also need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Middle position | 2. Zero position | 3. Rotated position | 4. Rest position |
|
||||
| ------------ |------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| <img src="../media/so101/leader_middle.webp?raw=true" alt="SO-101 leader arm middle position" title="SO-101 leader arm middle position" style="width:100%;"> | <img src="../media/so101/leader_zero.webp?raw=true" alt="SO-101 leader arm zero position" title="SO-101 leader arm zero position" style="width:100%;"> | <img src="../media/so101/leader_rotated.webp?raw=true" alt="SO-101 leader arm rotated position" title="SO-101 leader arm rotated position" style="width:100%;"> | <img src="../media/so101/leader_rest.webp?raw=true" alt="SO-101 leader arm rest position" title="SO-101 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=calibrate \
|
||||
--control.arms='["main_leader"]'
|
||||
```
|
||||
## Control your robot
|
||||
|
||||
Congrats 🎉, your robot is all set to learn a task on its own. Next we will explain to you how to train a neural network to autonomously control a real robot.
|
||||
|
||||
**You'll learn to:**
|
||||
1. How to record and visualize your dataset.
|
||||
2. How to train a policy using your data and prepare it for evaluation.
|
||||
3. How to evaluate your policy and visualize the results.
|
||||
|
||||
By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934).
|
||||
|
||||
This tutorial is specifically made for the affordable [SO-101](https://github.com/TheRobotStudio/SO-ARM100) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](https://aloha-2.github.io) by changing some configurations. The SO-101 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot.
|
||||
|
||||
During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously.
|
||||
|
||||
If you encounter any issues at any step of the tutorial, feel free to seek help on [Discord](https://discord.com/invite/s3KuuzsPFb) or don't hesitate to iterate with us on the tutorial by creating issues or pull requests.
|
||||
|
||||
## Teleoperate
|
||||
|
||||
Run this simple script to teleoperate your robot (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--robot.cameras='{}' \
|
||||
--control.type=teleoperate
|
||||
```
|
||||
|
||||
The teleoperate command will automatically:
|
||||
1. Identify any missing calibrations and initiate the calibration procedure.
|
||||
2. Connect the robot and start teleoperation.
|
||||
|
||||
## Setup Cameras
|
||||
|
||||
To connect a camera you have three options:
|
||||
1. OpenCVCamera which allows us to use any camera: usb, realsense, laptop webcam
|
||||
2. iPhone camera with MacOS
|
||||
3. Phone camera on Linux
|
||||
|
||||
### Use OpenCVCamera
|
||||
|
||||
The [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py) class allows you to efficiently record frames from most cameras using the [`opencv2`](https://docs.opencv.org) library. For more details on compatibility, see [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
|
||||
To instantiate an [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py), you need a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera like a webcam of a laptop, the camera index is usually `0` but it might differ, and the camera index might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system.
|
||||
|
||||
To find the camera indices, run the following utility script, which will save a few frames from each detected camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py \
|
||||
--images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
The output will look something like this if you have two cameras connected:
|
||||
```
|
||||
Mac or Windows detected. Finding available camera indices through scanning all indices from 0 to 60
|
||||
[...]
|
||||
Camera found at index 0
|
||||
Camera found at index 1
|
||||
[...]
|
||||
Connecting cameras
|
||||
OpenCVCamera(0, fps=30.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
OpenCVCamera(1, fps=24.0, width=1920.0, height=1080.0, color_mode=rgb)
|
||||
Saving images to outputs/images_from_opencv_cameras
|
||||
Frame: 0000 Latency (ms): 39.52
|
||||
[...]
|
||||
Frame: 0046 Latency (ms): 40.07
|
||||
Images have been saved to outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
Check the saved images in `outputs/images_from_opencv_cameras` to identify which camera index corresponds to which physical camera (e.g. `0` for `camera_00` or `1` for `camera_01`):
|
||||
```
|
||||
camera_00_frame_000000.png
|
||||
[...]
|
||||
camera_00_frame_000047.png
|
||||
camera_01_frame_000000.png
|
||||
[...]
|
||||
camera_01_frame_000047.png
|
||||
```
|
||||
|
||||
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
|
||||
|
||||
Now that you have the camera indexes, you should change them in the config. You can also change the fps, width or height of the camera.
|
||||
|
||||
The camera config is defined per robot, can be found here [`RobotConfig`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robot_devices/robots/configs.py) and looks like this:
|
||||
```python
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"wrist": OpenCVCameraConfig(
|
||||
camera_index=0, <-- UPDATE HERE
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"base": OpenCVCameraConfig(
|
||||
camera_index=1, <-- UPDATE HERE
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Use your phone
|
||||
#### Mac:
|
||||
|
||||
To use your iPhone as a camera on macOS, enable the Continuity Camera feature:
|
||||
- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later.
|
||||
- Sign in both devices with the same Apple ID.
|
||||
- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection.
|
||||
|
||||
For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac).
|
||||
|
||||
Your iPhone should be detected automatically when running the camera setup script in the next section.
|
||||
|
||||
#### Linux:
|
||||
|
||||
If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera
|
||||
|
||||
1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using:
|
||||
```python
|
||||
sudo apt install v4l2loopback-dkms v4l-utils
|
||||
```
|
||||
2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android.
|
||||
3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org):
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio
|
||||
```
|
||||
4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with:
|
||||
```python
|
||||
flatpak install flathub com.obsproject.Studio.Plugin.DroidCam
|
||||
```
|
||||
5. *Start OBS Studio*. Launch with:
|
||||
```python
|
||||
flatpak run com.obsproject.Studio
|
||||
```
|
||||
6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`.
|
||||
7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in.
|
||||
8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide).
|
||||
9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices:
|
||||
```python
|
||||
v4l2-ctl --list-devices
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
VirtualCam (platform:v4l2loopback-000):
|
||||
/dev/video1
|
||||
```
|
||||
10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`.
|
||||
```python
|
||||
v4l2-ctl -d /dev/video1 --get-fmt-video
|
||||
```
|
||||
You should see an entry like:
|
||||
```
|
||||
>>> Format Video Capture:
|
||||
>>> Width/Height : 640/480
|
||||
>>> Pixel Format : 'YUYV' (YUYV 4:2:2)
|
||||
```
|
||||
|
||||
Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed.
|
||||
|
||||
If everything is set up correctly, you can proceed with the rest of the tutorial.
|
||||
|
||||
### Add wrist camera
|
||||
If you have an additional camera you can add a wrist camera to the SO101. There are already many premade wrist camera holders that you can find in the SO101 repo: [Wrist camera's](https://github.com/TheRobotStudio/SO-ARM100#wrist-cameras)
|
||||
|
||||
## Teleoperate with cameras
|
||||
|
||||
We can now teleoperate again while at the same time visualizing the cameras and joint positions with `rerun`.
|
||||
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=teleoperate \
|
||||
--control.display_data=true
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-101.
|
||||
|
||||
We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens).
|
||||
|
||||
Add your token to the cli by running this command:
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Now you can record a dataset, to record 2 episodes and upload your dataset to the hub execute this command:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/so101_test \
|
||||
--control.tags='["so101","tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=2 \
|
||||
--control.display_data=true \
|
||||
--control.push_to_hub=true
|
||||
```
|
||||
|
||||
You will see a lot of lines appearing like this one:
|
||||
```
|
||||
INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz)
|
||||
```
|
||||
It contains:
|
||||
- `2024-08-10 15:02:58` which is the date and time of the call to the print function,
|
||||
- `ol_robot.py:219` which is the end of the file name and the line number where the print function is called (`lerobot/scripts/control_robot.py` line `219`).
|
||||
- `dt:33.34 (30.0hz)` which is the "delta time" or the number of milliseconds spent between the previous call to `robot.teleop_step(record_data=True)` and the current one, associated with the frequency (33.34 ms equals 30.0 Hz) ; note that we use `--fps 30` so we expect 30.0 Hz ; when a step takes more time, the line appears in yellow.
|
||||
- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm.
|
||||
- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading.
|
||||
- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm.
|
||||
- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously.
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
#### Dataset upload
|
||||
Locally your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/so101_test`). At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/so101_test
|
||||
```
|
||||
Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
|
||||
|
||||
You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot).
|
||||
|
||||
#### Record function
|
||||
|
||||
The `record` function provides a suite of tools for capturing and managing data during robot operation:
|
||||
1. Set the flow of data recording using command line arguments:
|
||||
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
|
||||
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
|
||||
- `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default).
|
||||
- `--control.num_episodes=50` defines the number of episodes to record (50 by default).
|
||||
2. Control the flow during data recording using keyboard keys:
|
||||
- Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||
- Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||
- Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||
3. Checkpoints are done set during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
|
||||
|
||||
#### Tips for gathering data
|
||||
|
||||
Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images.
|
||||
|
||||
In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions.
|
||||
|
||||
Avoid adding too much variation too quickly, as it may hinder your results.
|
||||
|
||||
#### Troubleshooting:
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/so101_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can visualize it locally with (via a window in the browser `http://127.0.0.1:9090` with the visualization tool):
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so101_test \
|
||||
--local-files-only 1
|
||||
```
|
||||
|
||||
This will launch a local web server that looks like this:
|
||||
|
||||
<div style="text-align:center;">
|
||||
<img src="../media/tutorial/visualize_dataset_html.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="100%"></img>
|
||||
</div>
|
||||
|
||||
## Replay an episode
|
||||
|
||||
A useful feature is the `replay` function, which allows to replay on your robot any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model.
|
||||
|
||||
You can replay the first episode on your robot with:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=replay \
|
||||
--control.fps=30 \
|
||||
--control.repo_id=${HF_USER}/so101_test \
|
||||
--control.episode=0
|
||||
```
|
||||
|
||||
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so101_test \
|
||||
--job_name=act_so101_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain the command:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so101_test/checkpoints`.
|
||||
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test \
|
||||
outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py \
|
||||
--robot.type=so101 \
|
||||
--control.type=record \
|
||||
--control.fps=30 \
|
||||
--control.single_task="Grasp a lego block and put it in the bin." \
|
||||
--control.repo_id=${HF_USER}/eval_act_so101_test \
|
||||
--control.tags='["tutorial"]' \
|
||||
--control.warmup_time_s=5 \
|
||||
--control.episode_time_s=30 \
|
||||
--control.reset_time_s=30 \
|
||||
--control.num_episodes=10 \
|
||||
--control.push_to_hub=true \
|
||||
--control.policy.path=outputs/train/act_so101_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
||||
@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
|
||||
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
|
||||
|
||||
It requires the installation of the 'gym_pusht' simulation environment. Install it by running:
|
||||
@@ -119,7 +119,7 @@ while not done:
|
||||
rewards.append(reward)
|
||||
frames.append(env.render())
|
||||
|
||||
# The rollout is considered done when the success state is reached (i.e. terminated is True),
|
||||
# The rollout is considered done when the success state is reach (i.e. terminated is True),
|
||||
# or the maximum number of iterations is reached (i.e. truncated is True)
|
||||
done = terminated | truncated | done
|
||||
step += 1
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This script demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||
> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
|
||||
|
||||
## The training script
|
||||
@@ -23,7 +23,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes:
|
||||
```python
|
||||
@@ -43,7 +43,7 @@ class DatasetConfig:
|
||||
```
|
||||
|
||||
This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`.
|
||||
From the command line, we can specify this value by using a very similar syntax `--dataset.repo_id=repo/id`.
|
||||
From the command line, we can specify this value with using a very similar syntax `--dataset.repo_id=repo/id`.
|
||||
|
||||
By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified.
|
||||
|
||||
@@ -135,7 +135,7 @@ will start a training run with the same configuration used for training [lerobot
|
||||
|
||||
## Resume training
|
||||
|
||||
Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to do that here.
|
||||
Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to that here.
|
||||
|
||||
Let's reuse the command from the previous run and add a few more options:
|
||||
```bash
|
||||
|
||||
@@ -377,7 +377,7 @@ robot = ManipulatorRobot(robot_config)
|
||||
|
||||
The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger.
|
||||
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha.
|
||||
|
||||
**Calibrate and Connect the ManipulatorRobot**
|
||||
|
||||
@@ -399,7 +399,7 @@ And here are the corresponding positions for the leader arm:
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask you to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively.
|
||||
|
||||
Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation.
|
||||
|
||||
@@ -622,7 +622,7 @@ camera_01_frame_000047.png
|
||||
|
||||
Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green.
|
||||
|
||||
Finally, run this code to instantiate and connect your camera:
|
||||
Finally, run this code to instantiate and connectyour camera:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
@@ -99,7 +99,7 @@ This is equivalent to running `stretch_robot_home.py`
|
||||
> **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first.
|
||||
|
||||
**Teleoperate**
|
||||
Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation).
|
||||
|
||||
Now try out teleoperation (see above documentation to learn about the gamepad controls):
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
first_idx = dataset.meta.episodes["dataset_from_index"][0]
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||
|
||||
@@ -66,7 +66,7 @@ def main():
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train and val datasets
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps
|
||||
)
|
||||
|
||||
144
examples/port_datasets/droid_rlds/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Port DROID 1.0.1 dataset to LeRobotDataset
|
||||
|
||||
## Download
|
||||
|
||||
TODO
|
||||
|
||||
It will take 2 TB in your local disk.
|
||||
|
||||
## Port on a single computer
|
||||
|
||||
First, install tensorflow dataset utilities to read from raw files:
|
||||
```bash
|
||||
pip install tensorflow
|
||||
pip install tensorflow_datasets
|
||||
```
|
||||
|
||||
Then run this script to start porting the dataset:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/port_droid.py \
|
||||
--raw-dir /your/data/droid/1.0.1 \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
It will take 400GB in your local disk.
|
||||
|
||||
As usual, your LeRobotDataset will be stored in your huggingface/lerobot cache folder.
|
||||
|
||||
WARNING: it will take 7 days for porting the dataset locally and 3 days to upload, so we will need to parallelize over multiple nodes on a slurm cluster.
|
||||
|
||||
NOTE: For development, run this script to start porting a shard:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/port.py \
|
||||
--raw-dir /your/data/droid/1.0.1 \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--num-shards 2048 \
|
||||
--shard-index 0
|
||||
```
|
||||
|
||||
## Port over SLURM
|
||||
|
||||
Install slurm utilities from Hugging Face:
|
||||
```bash
|
||||
pip install datatrove
|
||||
```
|
||||
|
||||
|
||||
### 1. Port one shard per job
|
||||
|
||||
Run this script to start porting shards of the dataset:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/slurm_port_shards.py \
|
||||
--raw-dir /your/data/droid/1.0.1 \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--logs-dir /your/logs \
|
||||
--job-name port_droid \
|
||||
--partition your_partition \
|
||||
--workers 2048 \
|
||||
--cpus-per-task 8 \
|
||||
--mem-per-cpu 1950M
|
||||
```
|
||||
|
||||
**Note on how to set your command line arguments**
|
||||
|
||||
Regarding `--partition`, find yours by running:
|
||||
```bash
|
||||
info --format="%R"`
|
||||
```
|
||||
and select the CPU partition if you have one. No GPU needed.
|
||||
|
||||
Regarding `--workers`, it is the number of slurm jobs you will launch in parallel. 2048 is the maximum number, since there is 2048 shards in Droid. This big number will certainly max-out your cluster.
|
||||
|
||||
Regarding `--cpus-per-task` and `--mem-per-cpu`, by default it will use ~16GB of RAM (8*1950M) which is recommended to load the raw frames and 8 CPUs which can be useful to parallelize the encoding of the frames.
|
||||
|
||||
Find the number of CPUs and Memory of the nodes of your partition by running:
|
||||
```bash
|
||||
sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m"
|
||||
```
|
||||
|
||||
**Useful commands to check progress and debug**
|
||||
|
||||
Check if your jobs are running:
|
||||
```bash
|
||||
squeue -u $USER`
|
||||
```
|
||||
|
||||
You should see a list with job indices like `15125385_155` where `15125385` is the index of the run and `155` is the worker index. The output/print of this worker is written in real time in `/your/logs/job_name/slurm_jobs/15125385_155.out`. For instance, you can inspect the content of this file by running `less /your/logs/job_name/slurm_jobs/15125385_155.out`.
|
||||
|
||||
Check the progression of your jobs by running:
|
||||
```bash
|
||||
jobs_status /your/logs
|
||||
```
|
||||
|
||||
If it's not 100% and no more slurm job is running, it means that some of them failed. Inspect the logs by running:
|
||||
```bash
|
||||
failed_logs /your/logs/job_name
|
||||
```
|
||||
|
||||
If there is an issue in the code, you can fix it in debug mode with `--slurm 0` which allows to set breakpoint:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 0 ...
|
||||
```
|
||||
|
||||
And you can relaunch the same command, which will skip the completed jobs:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 1 ...
|
||||
```
|
||||
|
||||
Once all jobs are completed, you will have one dataset per shard (e.g. `droid_1.0.1_world_2048_rank_1594`) saved on disk in your `/lerobot/home/dir/your_id` directory. You can find your `/lerobot/home/dir` by running:
|
||||
```bash
|
||||
python -c "from lerobot.common.constants import HF_LEROBOT_HOME;print(HF_LEROBOT_HOME)"
|
||||
```
|
||||
|
||||
|
||||
### 2. Aggregate all shards
|
||||
|
||||
Run this script to start aggregation:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/slurm_aggregate_shards.py \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--logs-dir /your/logs \
|
||||
--job-name aggr_droid \
|
||||
--partition your_partition \
|
||||
--workers 2048 \
|
||||
--cpus-per-task 8 \
|
||||
--mem-per-cpu 1950M
|
||||
```
|
||||
|
||||
Once all jobs are completed, you will have one dataset your `/lerobot/home/dir/your_id/droid_1.0.1` directory.
|
||||
|
||||
|
||||
### 3. Upload dataset
|
||||
|
||||
Run this script to start uploading:
|
||||
```bash
|
||||
python examples/port_datasets/droid_rlds/slurm_upload.py \
|
||||
--repo-id your_id/droid_1.0.1 \
|
||||
--logs-dir /your/logs \
|
||||
--job-name upload_droid \
|
||||
--partition your_partition \
|
||||
--workers 50 \
|
||||
--cpus-per-task 4 \
|
||||
--mem-per-cpu 1950M
|
||||
```
|
||||
430
examples/port_datasets/droid_rlds/port_droid.py
Normal file
@@ -0,0 +1,430 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||
|
||||
DROID_SHARDS = 2048
|
||||
DROID_FPS = 15
|
||||
DROID_ROBOT_TYPE = "Franka"
|
||||
|
||||
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
|
||||
DROID_FEATURES = {
|
||||
# true on first step of the episode
|
||||
"is_first": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
# true on last step of the episode
|
||||
"is_last": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
# true on last step of the episode if it is a terminal step, True for demos
|
||||
"is_terminal": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
# language_instruction is also stored as "task" to follow LeRobot standard
|
||||
"language_instruction": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"language_instruction_2": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"language_instruction_3": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"observation.state.gripper_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": {
|
||||
"axes": ["gripper"],
|
||||
},
|
||||
},
|
||||
"observation.state.cartesian_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"observation.state.joint_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {
|
||||
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||
},
|
||||
},
|
||||
# Add this new feature to follow LeRobot standard of using joint position + gripper
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": {
|
||||
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
|
||||
},
|
||||
},
|
||||
# Initially called wrist_image_left
|
||||
"observation.images.wrist_left": {
|
||||
"dtype": "video",
|
||||
"shape": (180, 320, 3),
|
||||
"names": [
|
||||
"height",
|
||||
"width",
|
||||
"channels",
|
||||
],
|
||||
},
|
||||
# Initially called exterior_image_1_left
|
||||
"observation.images.exterior_1_left": {
|
||||
"dtype": "video",
|
||||
"shape": (180, 320, 3),
|
||||
"names": [
|
||||
"height",
|
||||
"width",
|
||||
"channels",
|
||||
],
|
||||
},
|
||||
# Initially called exterior_image_2_left
|
||||
"observation.images.exterior_2_left": {
|
||||
"dtype": "video",
|
||||
"shape": (180, 320, 3),
|
||||
"names": [
|
||||
"height",
|
||||
"width",
|
||||
"channels",
|
||||
],
|
||||
},
|
||||
"action.gripper_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": {
|
||||
"axes": ["gripper"],
|
||||
},
|
||||
},
|
||||
"action.gripper_velocity": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": {
|
||||
"axes": ["gripper"],
|
||||
},
|
||||
},
|
||||
"action.cartesian_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"action.cartesian_velocity": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"action.joint_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {
|
||||
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||
},
|
||||
},
|
||||
"action.joint_velocity": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {
|
||||
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||
},
|
||||
},
|
||||
# This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
|
||||
"action.original": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
||||
},
|
||||
},
|
||||
# Add this new feature to follow LeRobot standard of using joint position + gripper
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": {
|
||||
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
|
||||
},
|
||||
},
|
||||
"discount": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
# Meta data that are the same for all frames in the episode
|
||||
"task_category": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"building": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"collector_id": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"date": {
|
||||
"dtype": "string",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"camera_extrinsics.wrist_left": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"camera_extrinsics.exterior_1_left": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"camera_extrinsics.exterior_2_left": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||
},
|
||||
},
|
||||
"is_episode_successful": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def is_episode_successful(tf_episode_metadata):
|
||||
# Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
|
||||
return "/success/" in tf_episode_metadata["file_path"].numpy().decode()
|
||||
|
||||
|
||||
def generate_lerobot_frames(tf_episode):
|
||||
m = tf_episode["episode_metadata"]
|
||||
frame_meta = {
|
||||
"task_category": m["building"].numpy().decode(),
|
||||
"building": m["building"].numpy().decode(),
|
||||
"collector_id": m["collector_id"].numpy().decode(),
|
||||
"date": m["date"].numpy().decode(),
|
||||
"camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
|
||||
"camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
|
||||
"camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
|
||||
"is_episode_successful": np.array([is_episode_successful(m)]),
|
||||
}
|
||||
for f in tf_episode["steps"]:
|
||||
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
|
||||
frame = {
|
||||
"is_first": np.array([f["is_first"].numpy()]),
|
||||
"is_last": np.array([f["is_last"].numpy()]),
|
||||
"is_terminal": np.array([f["is_terminal"].numpy()]),
|
||||
"language_instruction": f["language_instruction"].numpy().decode(),
|
||||
"language_instruction_2": f["language_instruction_2"].numpy().decode(),
|
||||
"language_instruction_3": f["language_instruction_3"].numpy().decode(),
|
||||
"observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
|
||||
"observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
|
||||
"observation.state.joint_position": f["observation"]["joint_position"].numpy(),
|
||||
"observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
|
||||
"observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
|
||||
"observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
|
||||
"action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
|
||||
"action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
|
||||
"action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
|
||||
"action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
|
||||
"action.joint_position": f["action_dict"]["joint_position"].numpy(),
|
||||
"action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
|
||||
"discount": np.array([f["discount"].numpy()]),
|
||||
"reward": np.array([f["reward"].numpy()]),
|
||||
"action.original": f["action"].numpy(),
|
||||
}
|
||||
|
||||
# language_instruction is also stored as "task" to follow LeRobot standard
|
||||
frame["task"] = frame["language_instruction"]
|
||||
|
||||
# Add this new feature to follow LeRobot standard of using joint position + gripper
|
||||
frame["observation.state"] = np.concatenate(
|
||||
[frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
|
||||
)
|
||||
frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])
|
||||
|
||||
# Meta data that are the same for all frames in the episode
|
||||
frame.update(frame_meta)
|
||||
|
||||
# Cast fp64 to fp32
|
||||
for key in frame:
|
||||
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
|
||||
frame[key] = frame[key].astype(np.float32)
|
||||
|
||||
yield frame
|
||||
|
||||
|
||||
def port_droid(
|
||||
raw_dir: Path,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = False,
|
||||
num_shards: int | None = None,
|
||||
shard_index: int | None = None,
|
||||
):
|
||||
dataset_name = raw_dir.parent.name
|
||||
version = raw_dir.name
|
||||
data_dir = raw_dir.parent.parent
|
||||
|
||||
builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
|
||||
|
||||
if num_shards is not None:
|
||||
tfds_num_shards = builder.info.splits["train"].num_shards
|
||||
if tfds_num_shards != DROID_SHARDS:
|
||||
raise ValueError(
|
||||
f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
|
||||
)
|
||||
if num_shards != tfds_num_shards:
|
||||
raise ValueError(
|
||||
f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
|
||||
)
|
||||
if shard_index >= tfds_num_shards:
|
||||
raise ValueError(
|
||||
f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
|
||||
)
|
||||
|
||||
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
|
||||
else:
|
||||
raw_dataset = builder.as_dataset(split="train")
|
||||
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
robot_type=DROID_ROBOT_TYPE,
|
||||
fps=DROID_FPS,
|
||||
features=DROID_FEATURES,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
num_episodes = raw_dataset.cardinality().numpy().item()
|
||||
logging.info(f"Number of episodes {num_episodes}")
|
||||
|
||||
for episode_index, episode in enumerate(raw_dataset):
|
||||
elapsed_time = time.time() - start_time
|
||||
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
|
||||
|
||||
logging.info(
|
||||
f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
|
||||
)
|
||||
|
||||
for frame in generate_lerobot_frames(episode):
|
||||
lerobot_dataset.add_frame(frame)
|
||||
|
||||
lerobot_dataset.save_episode()
|
||||
logging.info("Save_episode")
|
||||
|
||||
if push_to_hub:
|
||||
lerobot_dataset.push_to_hub(
|
||||
# Add openx tag, since it belongs to the openx collection of datasets
|
||||
tags=["openx"],
|
||||
private=False,
|
||||
)
|
||||
|
||||
|
||||
def validate_dataset(repo_id):
|
||||
"""Sanity check that ensure meta data can be loaded and all files are present."""
|
||||
meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
if meta.total_episodes == 0:
|
||||
raise ValueError("Number of episodes is 0.")
|
||||
|
||||
for ep_idx in range(meta.total_episodes):
|
||||
data_path = meta.root / meta.get_data_file_path(ep_idx)
|
||||
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Parquet file is missing in: {data_path}")
|
||||
|
||||
for vid_key in meta.video_keys:
|
||||
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not vid_path.exists():
|
||||
raise ValueError(f"Video file is missing in: {vid_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard-index",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
port_droid(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
293
examples/port_datasets/droid_rlds/slurm_aggregate_shards.py
Normal file
@@ -0,0 +1,293 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from lerobot.common.datasets.aggregate import validate_all_metadata
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
legacy_write_episode_stats,
|
||||
legacy_write_task,
|
||||
write_episode,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class AggregateDatasets(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
aggregated_repo_id: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
|
||||
self.create_aggr_dataset()
|
||||
|
||||
def create_aggr_dataset(self):
|
||||
init_logging()
|
||||
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
|
||||
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
|
||||
# Create resulting dataset folder
|
||||
aggr_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=self.aggr_repo_id,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
# find all tasks, deduplicate them, create new task indices for each dataset
|
||||
# indexed by dataset index
|
||||
datasets_task_index_to_aggr_task_index = {}
|
||||
aggr_task_index = 0
|
||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
|
||||
task_index_to_aggr_task_index = {}
|
||||
|
||||
for task_index, task in meta.tasks.items():
|
||||
if task not in aggr_meta.task_to_task_index:
|
||||
# add the task to aggr tasks mappings
|
||||
aggr_meta.tasks[aggr_task_index] = task
|
||||
aggr_meta.task_to_task_index[task] = aggr_task_index
|
||||
aggr_task_index += 1
|
||||
|
||||
# add task_index anyway
|
||||
task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]
|
||||
|
||||
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
|
||||
|
||||
logging.info("Prepare copy data and videos")
|
||||
datasets_ep_idx_to_aggr_ep_idx = {}
|
||||
datasets_aggr_episode_index_shift = {}
|
||||
aggr_episode_index_shift = 0
|
||||
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Prepare copy data and videos")):
|
||||
ep_idx_to_aggr_ep_idx = {}
|
||||
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index
|
||||
|
||||
datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
|
||||
datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift
|
||||
|
||||
# populate episodes
|
||||
for episode_index, episode_dict in meta.episodes.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
episode_dict["episode_index"] = aggr_episode_index
|
||||
aggr_meta.episodes[aggr_episode_index] = episode_dict
|
||||
|
||||
# populate episodes_stats
|
||||
for episode_index, episode_stats in meta.episodes_stats.items():
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
aggr_meta.episodes_stats[aggr_episode_index] = episode_stats
|
||||
|
||||
# populate info
|
||||
aggr_meta.info["total_episodes"] += meta.total_episodes
|
||||
aggr_meta.info["total_frames"] += meta.total_frames
|
||||
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
|
||||
|
||||
aggr_episode_index_shift += meta.total_episodes
|
||||
|
||||
logging.info("Write meta data")
|
||||
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
||||
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
|
||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
|
||||
|
||||
# create a new episodes jsonl with updated episode_index using write_episode
|
||||
for episode_dict in tqdm.tqdm(aggr_meta.episodes.values(), desc="Write episodes"):
|
||||
write_episode(episode_dict, aggr_meta.root)
|
||||
|
||||
# create a new episode_stats jsonl with updated episode_index using write_episode_stats
|
||||
for episode_index, episode_stats in tqdm.tqdm(
|
||||
aggr_meta.episodes_stats.items(), desc="Write episodes stats"
|
||||
):
|
||||
legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)
|
||||
|
||||
# create a new task jsonl with updated episode_index using write_task
|
||||
for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
|
||||
legacy_write_task(task_index, task, aggr_meta.root)
|
||||
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
|
||||
self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
|
||||
self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift
|
||||
|
||||
logging.info("Meta data done writing!")
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from lerobot.common.datasets.aggregate import get_update_episode_and_task_func
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
aggr_meta = LeRobotDatasetMetadata(self.aggr_repo_id)
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]
|
||||
|
||||
if world_size != len(all_metadata):
|
||||
raise ValueError()
|
||||
|
||||
dataset_index = rank
|
||||
meta = all_metadata[dataset_index]
|
||||
aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]
|
||||
|
||||
logging.info("Copy data")
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
|
||||
data_path = meta.root / meta.get_data_file_path(episode_index)
|
||||
aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)
|
||||
|
||||
# update episode_index and task_index
|
||||
df = pd.read_parquet(data_path)
|
||||
update_row_func = get_update_episode_and_task_func(
|
||||
aggr_episode_index_shift, self.datasets_task_index_to_aggr_task_index[dataset_index]
|
||||
)
|
||||
df = df.apply(update_row_func, axis=1)
|
||||
|
||||
aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(aggr_data_path)
|
||||
|
||||
logging.info("Copy videos")
|
||||
for episode_index in range(meta.total_episodes):
|
||||
aggr_episode_index = episode_index + aggr_episode_index_shift
|
||||
for vid_key in meta.video_keys:
|
||||
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
|
||||
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
|
||||
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(video_path, aggr_video_path)
|
||||
|
||||
# copy_command = f"cp {video_path} {aggr_video_path} &"
|
||||
# subprocess.Popen(copy_command, shell=True)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateDatasets(repo_ids, repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": DROID_SHARDS,
|
||||
"workers": workers,
|
||||
"time": "08:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": DROID_SHARDS,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
help="Path to logs directory for `datatrove`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggr_droid",
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of cpus that each slurm worker will use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="1950M",
|
||||
help="Memory per cpu that each worker will use.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
|
||||
repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
|
||||
aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
|
||||
aggregate_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
147
examples/port_datasets/droid_rlds/slurm_port_shards.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class PortDroidShards(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
raw_dir: Path | str,
|
||||
repo_id: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.raw_dir = Path(raw_dir)
|
||||
self.repo_id = repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
|
||||
from examples.port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
|
||||
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
|
||||
|
||||
try:
|
||||
validate_dataset(shard_repo_id)
|
||||
return
|
||||
except:
|
||||
pass
|
||||
|
||||
port_droid(
|
||||
self.raw_dir,
|
||||
shard_repo_id,
|
||||
push_to_hub=False,
|
||||
num_shards=world_size,
|
||||
shard_index=rank,
|
||||
)
|
||||
|
||||
validate_dataset(shard_repo_id)
|
||||
|
||||
|
||||
def make_port_executor(
|
||||
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
PortDroidShards(raw_dir, repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": DROID_SHARDS,
|
||||
"workers": workers,
|
||||
"time": "08:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
help="Path to logs directory for `datatrove`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="port_droid",
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of cpus that each slurm worker will use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="1950M",
|
||||
help="Memory per cpu that each worker will use.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
port_executor = make_port_executor(**kwargs)
|
||||
port_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
263
examples/port_datasets/droid_rlds/slurm_upload.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
|
||||
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class UploadDataset(PipelineStep):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
revision: str | None = None,
|
||||
tags: list | None = None,
|
||||
license: str | None = "apache-2.0",
|
||||
private: bool = False,
|
||||
distant_repo_id: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
|
||||
self.branch = branch
|
||||
self.tags = tags
|
||||
self.license = license
|
||||
self.private = private
|
||||
self.card_kwargs = card_kwargs
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
|
||||
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
|
||||
logging.warning(
|
||||
'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
|
||||
"variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
|
||||
)
|
||||
|
||||
self.create_repo()
|
||||
|
||||
def create_repo(self):
|
||||
logging.info(f"Loading meta data from {self.repo_id}...")
|
||||
meta = LeRobotDatasetMetadata(self.repo_id)
|
||||
|
||||
logging.info(f"Creating repo {self.distant_repo_id}...")
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
repo_id=self.distant_repo_id,
|
||||
private=self.private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
if self.branch:
|
||||
hub_api.create_branch(
|
||||
repo_id=self.distant_repo_id,
|
||||
branch=self.branch,
|
||||
revision=self.revision,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
if not hub_api.file_exists(
|
||||
self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
|
||||
):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
|
||||
|
||||
def list_files_recursively(directory):
|
||||
base_path = Path(directory)
|
||||
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
|
||||
|
||||
logging.info(f"Listing all local files from {self.repo_id}...")
|
||||
self.file_paths = list_files_recursively(meta.root)
|
||||
self.file_paths = sorted(self.file_paths)
|
||||
|
||||
def create_chunks(self, lst, n):
|
||||
from itertools import islice
|
||||
|
||||
it = iter(lst)
|
||||
return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]
|
||||
|
||||
def create_commits(self, additions):
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from huggingface_hub import create_commit
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
|
||||
FILES_BETWEEN_COMMITS = 10 # noqa: N806
|
||||
BASE_DELAY = 0.1 # noqa: N806
|
||||
MAX_RETRIES = 12 # noqa: N806
|
||||
|
||||
# Split the files into smaller chunks for faster commit
|
||||
# and avoiding "A commit has happened since" error
|
||||
num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
|
||||
chunks = self.create_chunks(additions, num_chunks)
|
||||
|
||||
for chunk in chunks:
|
||||
retries = 0
|
||||
while True:
|
||||
try:
|
||||
create_commit(
|
||||
self.distant_repo_id,
|
||||
repo_type="dataset",
|
||||
operations=chunk,
|
||||
commit_message=f"DataTrove upload ({len(chunk)} files)",
|
||||
revision=self.branch,
|
||||
)
|
||||
# TODO: every 100 chunks super_squach_commits()
|
||||
logging.info("create_commit completed!")
|
||||
break
|
||||
except HfHubHTTPError as e:
|
||||
if "A commit has happened since" in e.server_message:
|
||||
if retries >= MAX_RETRIES:
|
||||
logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
|
||||
raise e
|
||||
logging.info("Commit creation race condition issue. Waiting...")
|
||||
time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
|
||||
retries += 1
|
||||
else:
|
||||
raise e
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
|
||||
chunks = self.create_chunks(self.file_paths, world_size)
|
||||
file_paths = chunks[rank]
|
||||
|
||||
if len(file_paths) == 0:
|
||||
raise ValueError(file_paths)
|
||||
|
||||
logging.info("Pre-uploading LFS files...")
|
||||
for i, path in enumerate(file_paths):
|
||||
logging.info(f"{i}: {path}")
|
||||
|
||||
meta = LeRobotDatasetMetadata(self.repo_id)
|
||||
additions = [
|
||||
CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
|
||||
]
|
||||
preupload_lfs_files(
|
||||
repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
|
||||
)
|
||||
|
||||
logging.info("Creating commits...")
|
||||
self.create_commits(additions)
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
def make_upload_executor(
|
||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
UploadDataset(repo_id),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": DROID_SHARDS,
|
||||
"workers": workers,
|
||||
"time": "08:00:00",
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": DROID_SHARDS,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
help="Path to logs directory for `datatrove`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="upload_droid",
|
||||
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of slurm workers. It should be less than the maximum number of shards.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of cpus that each slurm worker will use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="1950M",
|
||||
help="Memory per cpu that each worker will use.",
|
||||
)
|
||||
|
||||
init_logging()
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
kwargs["slurm"] = kwargs.pop("slurm") == 1
|
||||
upload_executor = make_upload_executor(**kwargs)
|
||||
upload_executor.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -168,7 +168,12 @@ available_datasets = sorted(
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
available_policies = ["act", "diffusion", "tdmpc", "vqbet", "smolvla"]
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
"vqbet",
|
||||
]
|
||||
|
||||
# lists all available robots from `lerobot/common/robot_devices/robots`
|
||||
available_robots = [
|
||||
@@ -176,7 +181,6 @@ available_robots = [
|
||||
"koch_bimanual",
|
||||
"aloha",
|
||||
"so100",
|
||||
"so101",
|
||||
"moss",
|
||||
]
|
||||
|
||||
|
||||
416
lerobot/common/datasets/aggregate.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
concat_video_files,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
# validate same fps, robot_type, features
|
||||
|
||||
fps = all_metadata[0].fps
|
||||
robot_type = all_metadata[0].robot_type
|
||||
features = all_metadata[0].features
|
||||
|
||||
for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
|
||||
if fps != meta.fps:
|
||||
raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
|
||||
if robot_type != meta.robot_type:
|
||||
raise ValueError(
|
||||
f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
|
||||
)
|
||||
if features != meta.features:
|
||||
raise ValueError(
|
||||
f"Same features is expected, but got features={meta.features} instead of {features}."
|
||||
)
|
||||
|
||||
return fps, robot_type, features
|
||||
|
||||
|
||||
def update_data_df(df, src_meta, dst_meta):
|
||||
def _update(row):
|
||||
row["episode_index"] = row["episode_index"] + dst_meta["total_episodes"]
|
||||
row["index"] = row["index"] + dst_meta["total_frames"]
|
||||
task = src_meta.tasks.iloc[row["task_index"]].name
|
||||
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
|
||||
return row
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
|
||||
|
||||
def update_meta_data(
|
||||
df,
|
||||
dst_meta,
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
):
|
||||
def _update(row):
|
||||
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk_index"]
|
||||
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file_index"]
|
||||
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk_index"]
|
||||
row["data/file_index"] = row["data/file_index"] + data_idx["file_index"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk_index"]
|
||||
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file_index"]
|
||||
row[f"videos/{key}/from_timestamp"] = (
|
||||
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
row[f"videos/{key}/to_timestamp"] = (
|
||||
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
|
||||
return row
|
||||
|
||||
return df.apply(_update, axis=1)
|
||||
|
||||
|
||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
# Load metadata
|
||||
all_metadata = (
|
||||
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
if roots is None
|
||||
else [
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
image_keys = [key for key in features if features[key]["dtype"] == "image"]
|
||||
|
||||
# Initialize output dataset metadata
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=aggr_root,
|
||||
)
|
||||
|
||||
# Aggregate task info
|
||||
logging.info("Find all tasks")
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
|
||||
# Track counters and indices
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
# Process each dataset
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys, image_keys)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Helper Functions
|
||||
# -------------------------------
|
||||
|
||||
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx):
|
||||
"""
|
||||
Aggregates video chunks from a dataset into the aggregated dataset folder.
|
||||
"""
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Get unique (chunk, file) combinations
|
||||
unique_chunk_file_pairs = {
|
||||
(chunk, file)
|
||||
for chunk, file in zip(
|
||||
src_meta.episodes[f"videos/{key}/chunk_index"],
|
||||
src_meta.episodes[f"videos/{key}/file_index"],
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
|
||||
# Current target chunk/file index
|
||||
chunk_idx = video_idx["chunk_idx"]
|
||||
file_idx = video_idx["file_idx"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
file_index=src_file_idx,
|
||||
)
|
||||
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
if not dst_path.exists():
|
||||
# First write to this destination file
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
src_size = get_video_size_in_mb(src_path)
|
||||
dst_size = get_video_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
|
||||
# Rotate to a new chunk/file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
else:
|
||||
# Append to existing video file
|
||||
concat_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_meta.root,
|
||||
key,
|
||||
chunk_idx,
|
||||
file_idx,
|
||||
)
|
||||
|
||||
if src_size + dst_size >= DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
|
||||
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
|
||||
)
|
||||
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
|
||||
)
|
||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(aggr_path)
|
||||
else:
|
||||
# Update the existing parquet file with new rows
|
||||
aggr_df = pd.read_parquet(aggr_path)
|
||||
df = pd.concat([aggr_df, df], ignore_index=True)
|
||||
to_parquet_with_hf_images(df, aggr_path, dst_meta.image_keys)
|
||||
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx):
|
||||
unique_chunk_file_ids = {
|
||||
(c, f)
|
||||
for c, f in zip(
|
||||
src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
|
||||
)
|
||||
}
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
)
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_PATH,
|
||||
contains_images=len(dst_meta.image_keys) > 0
|
||||
)
|
||||
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
chunk_file_ids = {
|
||||
(c, f)
|
||||
for c, f in zip(
|
||||
src_meta.episodes["meta/episodes/chunk_index"],
|
||||
src_meta.episodes["meta/episodes/file_index"],
|
||||
strict=False,
|
||||
)
|
||||
}
|
||||
|
||||
for chunk_idx, file_idx in chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_meta_data(
|
||||
df,
|
||||
dst_meta,
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
# for k in video_keys:
|
||||
# video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
|
||||
|
||||
append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
)
|
||||
|
||||
return meta_idx
|
||||
|
||||
|
||||
def append_or_create_parquet_file(
|
||||
df: pd.DataFrame,
|
||||
src_path: Path,
|
||||
idx: dict[str, int],
|
||||
max_mb: float,
|
||||
chunk_size: int,
|
||||
default_path: str,
|
||||
contains_images: bool = False,
|
||||
):
|
||||
"""
|
||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||
|
||||
Parameters:
|
||||
df (pd.DataFrame): Data to write.
|
||||
src_path (Path): Path to source file (used to get size).
|
||||
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
||||
max_mb (float): Maximum allowed file size in MB.
|
||||
chunk_size (int): Maximum number of files per chunk.
|
||||
default_path (str): Format string for generating a new file path.
|
||||
|
||||
Returns:
|
||||
dict: Updated index dictionary.
|
||||
"""
|
||||
# Initial destination path
|
||||
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=idx["chunk"], file_index=idx["file"]
|
||||
)
|
||||
|
||||
# If destination file doesn't exist, just write the new one
|
||||
if not dst_path.exists():
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path)
|
||||
return idx
|
||||
|
||||
# Otherwise, check if we exceed the size limit
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= max_mb:
|
||||
# File is too large, move to a new one
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
else:
|
||||
# Append to existing file
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, new_path)
|
||||
else:
|
||||
final_df.to_parquet(new_path)
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
logging.info("write tasks")
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
|
||||
write_stats(aggr_meta.stats, aggr_meta.root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
num_shards = 2048
|
||||
repo_id = "cadene/droid_1.0.1_v30"
|
||||
aggr_repo_id = f"{repo_id}_compact_6"
|
||||
tags = ["openx"]
|
||||
|
||||
# num_shards = 210
|
||||
# repo_id = "cadene/agibot_alpha_v30"
|
||||
# aggr_repo_id = f"{repo_id}"
|
||||
# tags = None
|
||||
|
||||
# aggr_root = Path(f"/tmp/{aggr_repo_id}")
|
||||
aggr_root = HF_LEROBOT_HOME / aggr_repo_id
|
||||
if aggr_root.exists():
|
||||
shutil.rmtree(aggr_root)
|
||||
|
||||
repo_ids = []
|
||||
roots = []
|
||||
for rank in range(num_shards):
|
||||
shard_repo_id = f"{repo_id}_world_{num_shards}_rank_{rank}"
|
||||
shard_root = HF_LEROBOT_HOME / shard_repo_id
|
||||
try:
|
||||
meta = LeRobotDatasetMetadata(shard_repo_id, root=shard_root)
|
||||
if len(meta.video_keys) == 0:
|
||||
continue
|
||||
repo_ids.append(shard_repo_id)
|
||||
roots.append(shard_root)
|
||||
except:
|
||||
pass
|
||||
|
||||
if rank == 1:
|
||||
break
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids,
|
||||
aggr_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
|
||||
aggr_dataset = LeRobotDataset(repo_id=aggr_repo_id, root=aggr_root)
|
||||
# for i in tqdm.tqdm(range(len(aggr_dataset))):
|
||||
# aggr_dataset[i]
|
||||
# pass
|
||||
aggr_dataset.push_to_hub(tags=tags, upload_large_folder=True)
|
||||
@@ -47,6 +47,18 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
V30_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is in {version} format.
|
||||
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||
```
|
||||
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py --repo-id={repo_id}
|
||||
```
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
"""
|
||||
|
||||
FUTURE_MESSAGE = """
|
||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||
@@ -58,7 +70,14 @@ class CompatibilityError(Exception): ...
|
||||
|
||||
class BackwardCompatibilityError(CompatibilityError):
|
||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
if version.major == 3:
|
||||
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
elif version.major == 2:
|
||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def resolve_delta_timestamps(
|
||||
"observation.state": [-0.04, -0.02, 0]
|
||||
"observation.action": [-0.02, 0, 0.02]
|
||||
}
|
||||
returns `None` if the resulting dict is empty.
|
||||
returns `None` if the the resulting dict is empty.
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
for key in ds_meta.features:
|
||||
|
||||
@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchronously, which is critical to control a robot and record data
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it creates a threads pool of size `num_threads`.
|
||||
|
||||
@@ -16,16 +16,18 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
@@ -34,36 +36,41 @@ from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
INFO_PATH,
|
||||
TASKS_PATH,
|
||||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
concat_video_files,
|
||||
create_empty_dataset_info,
|
||||
create_lerobot_dataset_card,
|
||||
embed_images,
|
||||
flatten_dict,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_dataset_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
get_safe_version,
|
||||
get_video_duration_in_s,
|
||||
get_video_size_in_mb,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_episodes_stats,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
@@ -74,7 +81,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -98,20 +105,18 @@ class LeRobotDatasetMetadata:
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
# TODO(rcadene): instead of downloading all episodes metadata files,
|
||||
# download only the ones associated to the requested episodes. This would
|
||||
# require adding `episodes: list[int]` as argument.
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
self.episodes_stats = load_episodes_stats(self.root)
|
||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
@@ -133,18 +138,19 @@ class LeRobotDatasetMetadata:
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep["data/chunk_index"]
|
||||
file_idx = ep["data/file_index"]
|
||||
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
ep = self.episodes[ep_index]
|
||||
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
@@ -211,39 +217,108 @@ class LeRobotDatasetMetadata:
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise return None.
|
||||
"""
|
||||
return self.task_to_task_index.get(task, None)
|
||||
if task in self.tasks.index:
|
||||
return int(self.tasks.loc[task].task_index)
|
||||
else:
|
||||
return None
|
||||
|
||||
def add_task(self, task: str):
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
if self.tasks is None:
|
||||
new_tasks = tasks
|
||||
task_indices = range(len(tasks))
|
||||
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
||||
else:
|
||||
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||
self.tasks.loc[task] = task_idx
|
||||
|
||||
if len(new_tasks) > 0:
|
||||
# Update on disk
|
||||
write_tasks(self.tasks, self.root)
|
||||
|
||||
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||
"""Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
|
||||
|
||||
This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
|
||||
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||
updating of existing ones based on size constraints. After saving the metadata, it reloads
|
||||
the Hugging Face dataset to ensure it is up-to-date.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
Given a task in natural language, add it to the dictionary of tasks.
|
||||
"""
|
||||
if task in self.task_to_task_index:
|
||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
||||
# Convert buffer into HF Dataset
|
||||
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||
ep_dataset = Dataset.from_dict(episode_dict)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
num_frames = episode_dict["length"][0]
|
||||
|
||||
task_index = self.info["total_tasks"]
|
||||
self.task_to_task_index[task] = task_index
|
||||
self.tasks[task_index] = task
|
||||
self.info["total_tasks"] += 1
|
||||
if self.episodes is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [0]
|
||||
df["dataset_to_index"] = [num_frames]
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.episodes[-1]
|
||||
chunk_idx = latest_ep["meta/episodes/chunk_index"]
|
||||
file_idx = latest_ep["meta/episodes/file_index"]
|
||||
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
||||
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||
|
||||
# Update the existing pandas dataframe with new row
|
||||
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||
df["meta/episodes/file_index"] = [file_idx]
|
||||
df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
|
||||
df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
|
||||
# Size limit wasnt reached, concatenate latest dataframe with new one
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(path, index=False)
|
||||
|
||||
# Update the Hugging Face dataset by reloading it.
|
||||
# This process should be fast because only the latest Parquet file has been modified.
|
||||
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
|
||||
self.episodes = load_episodes(self.root)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
@@ -251,32 +326,28 @@ class LeRobotDatasetMetadata:
|
||||
episode_length: int,
|
||||
episode_tasks: list[str],
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episodes[episode_index] = episode_dict
|
||||
write_episode(episode_dict, self.root)
|
||||
episode_dict.update(episode_metadata)
|
||||
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
self.episodes_stats[episode_index] = episode_stats
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
# Update info
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
@@ -341,8 +412,9 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.tasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
@@ -487,29 +559,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
# Load actual data
|
||||
try:
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download_episodes(download_videos)
|
||||
self.download(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
@@ -585,7 +645,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
def download(self, download_videos: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
@@ -593,11 +653,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
@@ -610,19 +669,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset = load_nested_dataset(self.root / "data")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -630,8 +683,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
features = get_hf_features_from_features(self.features)
|
||||
ft_dict = {col: [] for col in features}
|
||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||
|
||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -663,15 +714,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -685,7 +737,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
query_timestamps = {}
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
@@ -694,7 +746,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
@@ -705,10 +757,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
# Thus we load the start timestamp of the episode on this mp4 and
|
||||
# shift the query timestamp accordingly.
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||
frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
@@ -746,8 +805,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks[task_idx]
|
||||
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -777,6 +835,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
@@ -855,11 +916,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Add new tasks to the tasks dictionary
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self.meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
@@ -871,51 +929,154 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
|
||||
# ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
# ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
# check_timestamps_sync(
|
||||
# episode_buffer["timestamp"],
|
||||
# episode_buffer["episode_index"],
|
||||
# ep_data_index_np,
|
||||
# self.fps,
|
||||
# self.tolerance_s,
|
||||
# )
|
||||
|
||||
# TODO(rcadene): images are also deleted in clear_episode_buffer
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
if not episode_data:
|
||||
# Reset episode buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
|
||||
|
||||
This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
|
||||
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||
updating of existing ones based on size constraints. After saving the data, it reloads
|
||||
the Hugging Face dataset to ensure it is up-to-date.
|
||||
|
||||
Notes: We both need to update parquet files and HF dataset:
|
||||
- `pandas` loads parquet file in RAM
|
||||
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||
or loads directly from pyarrow cache.
|
||||
"""
|
||||
# Convert buffer into HF Dataset
|
||||
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||
ep_num_frames = len(ep_dataset)
|
||||
df = pd.DataFrame(ep_dataset)
|
||||
|
||||
if self.meta.episodes is None:
|
||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
latest_num_frames = 0
|
||||
else:
|
||||
# Retrieve information from the latest parquet file
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
|
||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||
latest_num_frames = get_parquet_num_frames(latest_path)
|
||||
|
||||
# Determine if a new parquet file is needed
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
|
||||
# Size limit is reached, prepare new parquet file
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
latest_num_frames = 0
|
||||
else:
|
||||
# Update the existing parquet file with new rows
|
||||
latest_df = pd.read_parquet(latest_path)
|
||||
df = pd.concat([latest_df, df], ignore_index=True)
|
||||
|
||||
# Write the resulting dataframe from RAM to disk
|
||||
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if len(self.meta.image_keys) > 0:
|
||||
to_parquet_with_hf_images(df, path)
|
||||
else:
|
||||
df.to_parquet(path)
|
||||
|
||||
# Update the Hugging Face dataset by reloading it.
|
||||
# This process should be fast because only the latest Parquet file has been modified.
|
||||
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": latest_num_frames,
|
||||
"dataset_to_index": latest_num_frames + ep_num_frames,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int):
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
if self.meta.episodes is None:
|
||||
# Initialize indices for a new dataset made of the first episode data
|
||||
chunk_idx, file_idx = 0, 0
|
||||
latest_duration_in_s = 0
|
||||
new_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Retrieve information from the latest video file
|
||||
latest_ep = self.meta.episodes[-1]
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
latest_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_video_size_in_mb(latest_path)
|
||||
latest_duration_in_s = get_video_duration_in_s(latest_path)
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
|
||||
# Move temporary episode video to a new video file in the dataset
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||
new_path = self.root / self.meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
# Update latest video file
|
||||
concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
|
||||
|
||||
# Remove temporary directory
|
||||
shutil.rmtree(str(ep_path.parent))
|
||||
|
||||
metadata = {
|
||||
"episode_index": episode_index,
|
||||
f"videos/{video_key}/chunk_index": chunk_idx,
|
||||
f"videos/{video_key}/file_index": file_idx,
|
||||
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
|
||||
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
@@ -944,7 +1105,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
|
||||
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
@@ -955,34 +1116,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
for ep_idx in range(self.meta.total_episodes):
|
||||
self.encode_episode_videos(ep_idx)
|
||||
|
||||
def encode_episode_videos(self, episode_index: int) -> dict:
|
||||
"""
|
||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
video_paths = {}
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
video_paths[key] = str(video_path)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
return video_paths
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1027,7 +1170,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
|
||||
@@ -337,13 +337,11 @@ def compute_sampler_weights(
|
||||
if len(offline_dataset) > 0:
|
||||
offline_data_mask_indices = []
|
||||
for start_index, end_index in zip(
|
||||
offline_dataset.episode_data_index["from"],
|
||||
offline_dataset.episode_data_index["to"],
|
||||
offline_dataset.meta.episodes["dataset_from_index"],
|
||||
offline_dataset.meta.episodes["dataset_to_index"],
|
||||
strict=True,
|
||||
):
|
||||
offline_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
||||
)
|
||||
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
|
||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||
weights.append(
|
||||
|
||||
@@ -21,7 +21,8 @@ import torch
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
self,
|
||||
episode_data_index: dict,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
episode_indices_to_use: Union[list, None] = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
@@ -30,7 +31,8 @@ class EpisodeAwareSampler:
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
|
||||
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
@@ -39,12 +41,10 @@ class EpisodeAwareSampler:
|
||||
"""
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
)
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
|
||||
@@ -128,7 +128,7 @@ class SharpnessJitter(Transform):
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.")
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
|
||||
@@ -17,18 +17,23 @@ import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
@@ -42,19 +47,25 @@ from lerobot.common.datasets.backward_compatibility import (
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
EPISODES_DIR = "meta/episodes"
|
||||
DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -75,6 +86,115 @@ DEFAULT_FEATURES = {
|
||||
}
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path):
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
total_uncompressed_size = 0
|
||||
for row_group in range(metadata.num_row_groups):
|
||||
rg_metadata = metadata.row_group(row_group)
|
||||
for column in range(rg_metadata.num_columns):
|
||||
col_metadata = rg_metadata.column(column)
|
||||
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||
return total_uncompressed_size / (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes / (1024**2)
|
||||
|
||||
|
||||
def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
|
||||
# TODO(rcadene): unused?
|
||||
memory_usage_bytes = df.memory_usage(deep=True).sum()
|
||||
return memory_usage_bytes / (1024**2)
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
chunk_idx += 1
|
||||
else:
|
||||
file_idx += 1
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
def load_nested_dataset(pq_dir: Path) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
datasets = [Dataset.from_parquet(str(path)) for path in paths]
|
||||
return concatenate_datasets(datasets)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path):
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_video_size_in_mb(mp4_path: Path):
|
||||
file_size_bytes = mp4_path.stat().st_size
|
||||
file_size_mb = file_size_bytes / (1024**2)
|
||||
return file_size_mb
|
||||
|
||||
|
||||
def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
|
||||
# TODO(rcadene): move to video_utils.py
|
||||
# TODO(rcadene): add docstring
|
||||
tmp_dir = Path(tempfile.mkdtemp(dir=root))
|
||||
# Create a text file with the list of files to concatenate
|
||||
path_concat_video_files = tmp_dir / "concat_video_files.txt"
|
||||
with open(path_concat_video_files, "w") as f:
|
||||
for ep_path in paths_to_cat:
|
||||
f.write(f"file '{str(ep_path)}'\n")
|
||||
|
||||
path_tmp_output = tmp_dir / "tmp_output.mp4"
|
||||
command = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"concat",
|
||||
"-safe",
|
||||
"0",
|
||||
"-i",
|
||||
str(path_concat_video_files),
|
||||
"-c",
|
||||
"copy",
|
||||
str(path_tmp_output),
|
||||
]
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
output_path = root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(path_tmp_output), str(output_path))
|
||||
shutil.rmtree(str(tmp_dir))
|
||||
|
||||
|
||||
def get_video_duration_in_s(mp4_file: Path):
|
||||
# TODO(rcadene): move to video_utils.py
|
||||
command = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
str(mp4_file),
|
||||
]
|
||||
result = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
return float(result.stdout)
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
@@ -107,23 +227,13 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
return getter
|
||||
|
||||
for key in split_keys[1:]:
|
||||
getter = getter[key]
|
||||
|
||||
return getter
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
@@ -153,23 +263,6 @@ def write_json(data: dict, fpath: Path) -> None:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
@@ -198,43 +291,42 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
||||
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset.to_parquet(path)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
|
||||
path = local_dir / DEFAULT_TASKS_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tasks.to_parquet(path)
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
def load_tasks(local_dir: Path):
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
def write_episodes(episodes: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes.to_parquet(fpath)
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||
# Select episode features/columns containing references to episode data and videos
|
||||
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
|
||||
# This is to speedup access to these data, instead of having to load episode stats.
|
||||
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
|
||||
return episodes
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
@@ -388,6 +480,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
# TODO(rcadene): add fps for each feature
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
camera_ft = {
|
||||
@@ -441,31 +534,17 @@ def create_empty_dataset_info(
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"total_videos": 0,
|
||||
"total_chunks": 0,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lengths),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
timestamps: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
@@ -811,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
|
||||
""" This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formated images are returned.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
|
||||
@@ -121,12 +121,12 @@ from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
LEGACY_EPISODES_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
@@ -291,14 +291,12 @@ def split_parquet_by_episodes(
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
chunk_dir = "/".join(DEFAULT_DATA_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
output_file = output_dir / DEFAULT_DATA_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
pq.write_table(ep_table, output_file)
|
||||
|
||||
return episode_lengths
|
||||
@@ -496,7 +494,7 @@ def convert_dataset(
|
||||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
write_jsonlines(tasks, v20_dir / LEGACY_TASKS_PATH)
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
@@ -546,7 +544,7 @@ def convert_dataset(
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
write_jsonlines(episodes, v20_dir / LEGACY_EPISODES_PATH)
|
||||
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
@@ -560,7 +558,7 @@ def convert_dataset(
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
@@ -61,8 +61,8 @@ def convert_dataset(
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
||||
if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
|
||||
@@ -19,7 +19,7 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
from lerobot.common.datasets.utils import legacy_write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
@@ -72,7 +72,7 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
|
||||
452
lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to
|
||||
3.0. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import tqdm
|
||||
from datasets import Dataset, Features, Image
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from requests import HTTPError
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
cast_stats_to_numpy,
|
||||
concat_video_files,
|
||||
flatten_dict,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
get_video_duration_in_s,
|
||||
get_video_size_in_mb,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
"""
|
||||
-------------------------
|
||||
OLD
|
||||
data/chunk-000/episode_000000.parquet
|
||||
|
||||
NEW
|
||||
data/chunk-000/file_000.parquet
|
||||
-------------------------
|
||||
OLD
|
||||
videos/chunk-000/CAMERA/episode_000000.mp4
|
||||
|
||||
NEW
|
||||
videos/chunk-000/file_000.mp4
|
||||
-------------------------
|
||||
OLD
|
||||
episodes.jsonl
|
||||
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||
|
||||
NEW
|
||||
meta/episodes/chunk-000/episodes_000.parquet
|
||||
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||
-------------------------
|
||||
OLD
|
||||
tasks.jsonl
|
||||
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||
|
||||
NEW
|
||||
meta/tasks/chunk-000/file_000.parquet
|
||||
task_index | task
|
||||
-------------------------
|
||||
OLD
|
||||
episodes_stats.jsonl
|
||||
|
||||
NEW
|
||||
meta/episodes_stats/chunk-000/file_000.parquet
|
||||
episode_index | mean | std | min | max
|
||||
-------------------------
|
||||
UPDATE
|
||||
meta/info.json
|
||||
-------------------------
|
||||
"""
|
||||
|
||||
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def legacy_load_episodes(local_dir: Path) -> dict:
|
||||
episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def legacy_load_episodes_stats(local_dir: Path) -> dict:
|
||||
episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
|
||||
}
|
||||
|
||||
|
||||
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
return tasks, task_to_task_index
|
||||
|
||||
|
||||
def convert_tasks(root, new_root):
|
||||
tasks, _ = legacy_load_tasks(root)
|
||||
task_indices = tasks.keys()
|
||||
task_strings = tasks.values()
|
||||
df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings)
|
||||
write_tasks(df_tasks, new_root)
|
||||
|
||||
|
||||
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
|
||||
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
|
||||
dataframes = [pd.read_parquet(file) for file in paths_to_cat]
|
||||
# Concatenate all DataFrames along rows
|
||||
concatenated_df = pd.concat(dataframes, ignore_index=True)
|
||||
|
||||
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(image_keys) > 0:
|
||||
schema = pa.Schema.from_pandas(concatenated_df)
|
||||
features = Features.from_arrow_schema(schema)
|
||||
for key in image_keys:
|
||||
features[key] = Image()
|
||||
schema = features.arrow_schema
|
||||
else:
|
||||
schema = None
|
||||
|
||||
concatenated_df.to_parquet(path, index=False, schema=schema)
|
||||
|
||||
|
||||
def convert_data(root, new_root):
|
||||
data_dir = root / "data"
|
||||
ep_paths = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
image_keys = get_image_keys(root)
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
num_frames = 0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
for ep_path in ep_paths:
|
||||
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
|
||||
ep_num_frames = get_parquet_num_frames(ep_path)
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": num_frames,
|
||||
"dataset_to_index": num_frames + ep_num_frames,
|
||||
}
|
||||
size_in_mb += ep_size_in_mb
|
||||
num_frames += ep_num_frames
|
||||
episodes_metadata.append(ep_metadata)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
num_frames = ep_num_frames
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Write remaining data if any
|
||||
if paths_to_cat:
|
||||
concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def get_image_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
return image_keys
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path):
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
|
||||
video_keys = sorted(video_keys)
|
||||
|
||||
eps_metadata_per_cam = []
|
||||
for camera in video_keys:
|
||||
eps_metadata = convert_videos_of_camera(root, new_root, camera)
|
||||
eps_metadata_per_cam.append(eps_metadata)
|
||||
|
||||
num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
|
||||
if len(set(num_eps_per_cam)) != 1:
|
||||
raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")
|
||||
|
||||
episods_metadata = []
|
||||
num_cameras = len(video_keys)
|
||||
num_episodes = num_eps_per_cam[0]
|
||||
for ep_idx in range(num_episodes):
|
||||
# Sanity check
|
||||
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
|
||||
ep_ids += [ep_idx]
|
||||
if len(set(ep_ids)) != 1:
|
||||
raise ValueError(f"All episode indices need to match ({ep_ids}).")
|
||||
|
||||
ep_dict = {}
|
||||
for cam_idx in range(num_cameras):
|
||||
ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx])
|
||||
episods_metadata.append(ep_dict)
|
||||
|
||||
return episods_metadata
|
||||
|
||||
|
||||
def convert_videos_of_camera(root: Path, new_root: Path, video_key):
|
||||
# Access old paths to mp4
|
||||
videos_dir = root / "videos"
|
||||
ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
|
||||
|
||||
ep_idx = 0
|
||||
chunk_idx = 0
|
||||
file_idx = 0
|
||||
size_in_mb = 0
|
||||
duration_in_s = 0.0
|
||||
paths_to_cat = []
|
||||
episodes_metadata = []
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_metadata = {
|
||||
"episode_index": ep_idx,
|
||||
f"videos/{video_key}/chunk_index": chunk_idx,
|
||||
f"videos/{video_key}/file_index": file_idx,
|
||||
f"videos/{video_key}/from_timestamp": duration_in_s,
|
||||
f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
size_in_mb += ep_size_in_mb
|
||||
duration_in_s += ep_duration_in_s
|
||||
episodes_metadata.append(ep_metadata)
|
||||
ep_idx += 1
|
||||
|
||||
if size_in_mb < DEFAULT_VIDEO_FILE_SIZE_IN_MB:
|
||||
paths_to_cat.append(ep_path)
|
||||
continue
|
||||
|
||||
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
|
||||
|
||||
# Reset for the next file
|
||||
size_in_mb = ep_size_in_mb
|
||||
duration_in_s = ep_duration_in_s
|
||||
paths_to_cat = [ep_path]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
|
||||
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
):
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
episodes_stats_vals = list(episodes_stats.values())
|
||||
episodes_stats_keys = list(episodes_stats.keys())
|
||||
|
||||
for i in range(num_episodes):
|
||||
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
|
||||
ep_metadata = episodes_metadata[i]
|
||||
ep_stats = episodes_stats_vals[i]
|
||||
|
||||
ep_ids_set = {
|
||||
ep_legacy_metadata["episode_index"],
|
||||
ep_metadata["episode_index"],
|
||||
episodes_stats_keys[i],
|
||||
}
|
||||
|
||||
if episodes_videos is None:
|
||||
ep_video = {}
|
||||
else:
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
ep_dict["meta/episodes/file_index"] = 0
|
||||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
|
||||
stats = aggregate_stats(list(episodes_stats.values()))
|
||||
write_stats(stats, new_root)
|
||||
|
||||
|
||||
def convert_info(root, new_root):
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = "v3.0"
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
info["data_files_size_in_mb"] = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
info["video_files_size_in_mb"] = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
info["data_path"] = DEFAULT_DATA_PATH
|
||||
info["video_path"] = DEFAULT_VIDEO_PATH
|
||||
info["fps"] = float(info["fps"])
|
||||
for key in info["features"]:
|
||||
if info["features"][key]["dtype"] == "video":
|
||||
# already has fps in video_info
|
||||
continue
|
||||
info["features"][key]["fps"] = info["fps"]
|
||||
write_info(info, new_root)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
root = HF_LEROBOT_HOME / repo_id
|
||||
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
|
||||
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
|
||||
|
||||
if old_root.is_dir() and root.is_dir():
|
||||
shutil.rmtree(str(root))
|
||||
shutil.move(str(old_root), str(root))
|
||||
|
||||
if new_root.is_dir():
|
||||
shutil.rmtree(new_root)
|
||||
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
repo_type="dataset",
|
||||
revision=V21,
|
||||
local_dir=root,
|
||||
)
|
||||
|
||||
convert_info(root, new_root)
|
||||
convert_tasks(root, new_root)
|
||||
episodes_metadata = convert_data(root, new_root)
|
||||
episodes_videos_metadata = convert_videos(root, new_root)
|
||||
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
|
||||
hub_api = HfApi()
|
||||
try:
|
||||
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||
except HTTPError as e:
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
)
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
@@ -13,15 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -101,7 +102,7 @@ def decode_video_frames_torchvision(
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesn't support accurate seek
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
@@ -154,6 +155,7 @@ def decode_video_frames_torchvision(
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
# TODO(rcadene): remove torch.stack
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
@@ -251,83 +253,51 @@ def encode_video_frames(
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
log_level: str | None = "quiet",
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
# Check encoder availability
|
||||
if vcodec not in ["h264", "hevc", "libsvtav1"]:
|
||||
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=overwrite)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logging.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame_" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
|
||||
ffmpeg_args = OrderedDict(
|
||||
[
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame-%06d.png")),
|
||||
("-vcodec", vcodec),
|
||||
("-pix_fmt", pix_fmt),
|
||||
]
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
dummy_image = Image.open(input_list[0])
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = {}
|
||||
|
||||
if g is not None:
|
||||
video_options["g"] = str(g)
|
||||
ffmpeg_args["-g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
ffmpeg_args["-crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
ffmpeg_args[key] = value
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
|
||||
logging.getLogger("libav").setLevel(log_level)
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
|
||||
# Create and open output file (overwrite by default)
|
||||
with av.open(str(video_path), "w") as output:
|
||||
output_stream = output.add_stream(vcodec, fps, options=video_options)
|
||||
output_stream.pix_fmt = pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
# Loop through input frames and encode them
|
||||
for input_data in input_list:
|
||||
input_image = Image.open(input_data).convert("RGB")
|
||||
input_frame = av.VideoFrame.from_image(input_image)
|
||||
packet = output_stream.encode(input_frame)
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Flush the encoder
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
output.mux(packet)
|
||||
|
||||
# Reset logging level
|
||||
if log_level is not None:
|
||||
av.logging.restore_default_callback()
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not video_path.exists():
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {video_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -363,68 +333,78 @@ with warnings.catch_warnings():
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
with av.open(str(video_path), "r") as audio_file:
|
||||
try:
|
||||
audio_stream = audio_file.streams.audio[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {"has_audio": False}
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
audio_info["audio.channels"] = audio_stream.channels
|
||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
||||
# In an ideal loseless case : fixed number of bits per sample.
|
||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
||||
audio_info["has_audio"] = True
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
return audio_info
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
# Getting video stream information
|
||||
video_info = {}
|
||||
with av.open(str(video_path), "r") as video_file:
|
||||
try:
|
||||
video_stream = video_file.streams.video[0]
|
||||
except IndexError:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
return {}
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
video_info["video.height"] = video_stream.height
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
||||
video_info["video.channels"] = pixel_channels
|
||||
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -15,6 +15,5 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
@@ -27,7 +27,6 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -60,10 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "smolvla":
|
||||
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -81,8 +76,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
|
||||
interpolate_like_pi=self.config.interpolate_like_pi,
|
||||
)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expected by siglip
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolvla")
|
||||
@dataclass
|
||||
class SmolVLAConfig(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512)
|
||||
|
||||
# Add empty images. Used by smolvla_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = True
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
|
||||
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
|
||||
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
|
||||
|
||||
attention_mode: str = "cross_attn"
|
||||
|
||||
prefix_length: int = -1
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
|
||||
num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
|
||||
num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers)
|
||||
self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers
|
||||
expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM)
|
||||
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,802 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SmolVLA:
|
||||
|
||||
[Paper](https://huggingface.co/papers/2506.01844)
|
||||
|
||||
Designed by Hugging Face.
|
||||
|
||||
Install smolvla extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[smolvla]"
|
||||
```
|
||||
|
||||
Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
and an action expert.
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=smolvla \
|
||||
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
--batch_size=64 \
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
Example of using the smolvla pretrained model outside LeRobot training framework:
|
||||
```python
|
||||
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.common.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with smolvla which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class SmolVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = SmolVLAConfig
|
||||
name = "smolvla"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
if f"{key}_padding_mask" in batch:
|
||||
mask = batch[f"{key}_padding_mask"].bool()
|
||||
else:
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
Efficiently pads a tensor along sequence dimension to match max_len.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
|
||||
max_len (int): Fixed sequence length.
|
||||
pad_value (int/float): Value for padding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||
"""
|
||||
b, d = tensor.shape[:2]
|
||||
|
||||
# Create a padded tensor of max_len and copy the existing values
|
||||
padded_tensor = torch.full(
|
||||
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, :d] = tensor # Efficient in-place copy
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
class VLAFlowMatching(nn.Module):
|
||||
"""
|
||||
SmolVLA
|
||||
|
||||
[Paper]()
|
||||
|
||||
Designed by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌─────────┐ ┌─|────┐ │
|
||||
│ | │────► │ │ │
|
||||
│ | │ kv │ │ │
|
||||
│ | │────► │Action│ │
|
||||
│ | VLM │cache │Expert│ |
|
||||
│ │ │────► | │ │
|
||||
│ │ │ │ │ │
|
||||
│ └▲──▲───▲─┘ └───▲──┘ |
|
||||
│ │ | | │ |
|
||||
│ | | | noise │
|
||||
│ │ │ state │
|
||||
│ │ language tokens │
|
||||
│ image(s) │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||
model_id=self.config.vlm_model_name,
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
load_vlm_weights=self.config.load_vlm_weights,
|
||||
attention_mode=self.config.attention_mode,
|
||||
num_expert_layers=self.config.num_expert_layers,
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
)
|
||||
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for SmolVLM transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
for _img_idx, (
|
||||
img,
|
||||
img_mask,
|
||||
) in enumerate(zip(images, img_masks, strict=False)):
|
||||
if self.add_image_special_tokens:
|
||||
image_start_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_start_mask = torch.ones_like(
|
||||
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||
)
|
||||
att_masks += [0] * (image_start_mask.shape[-1])
|
||||
embs.append(image_start_token)
|
||||
pad_masks.append(image_start_mask)
|
||||
|
||||
img_emb = self.vlm_with_expert.embed_image(img)
|
||||
img_emb = img_emb
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
att_masks += [0] * (num_img_embs)
|
||||
if self.add_image_special_tokens:
|
||||
image_end_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_end_mask = torch.ones_like(
|
||||
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||
)
|
||||
embs.append(image_end_token)
|
||||
pad_masks.append(image_end_mask)
|
||||
att_masks += [0] * (image_end_mask.shape[1])
|
||||
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
embs.append(state_emb)
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
states_seq_len = state_emb.shape[1]
|
||||
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1] * (states_seq_len)
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :]
|
||||
|
||||
seq_len = pad_masks.shape[1]
|
||||
if seq_len < self.prefix_length:
|
||||
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
|
||||
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
|
||||
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
|
||||
|
||||
att_masks = att_masks.expand(bsize, -1)
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
device = action_emb.device
|
||||
bsize = action_emb.shape[0]
|
||||
dtype = action_emb.dtype
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.vlm_with_expert.expert_hidden_size,
|
||||
self.config.min_period,
|
||||
self.config.max_period,
|
||||
device=device,
|
||||
)
|
||||
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] * self.config.chunk_size
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
(_, suffix_out), _ = self.vlm_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.vlm_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.vlm_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
@@ -1,550 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
self.set_requires_grad()
|
||||
|
||||
def get_vlm_model(self):
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(k in name for k in frozen_layers):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if "lm_head" in name:
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
seq_len = query_states.shape[1]
|
||||
if seq_len < position_ids.shape[1]:
|
||||
_position_ids = position_ids[:, :seq_len]
|
||||
_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
else:
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(query_states, position_ids_)
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_output = attention_interface(
|
||||
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
query_states = apply_rope(query_state, position_id)
|
||||
key_states = apply_rope(key_state, position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
expert_position_id = position_ids
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
|
||||
|
||||
expert_input_shape = expert_hidden_states.shape[:-1]
|
||||
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
att_outputs.append(None)
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list:
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
for i in range(self.num_vlm_layers):
|
||||
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
|
||||
expert_layer = None
|
||||
else:
|
||||
expert_layer_index = i // multiple_of if multiple_of > 0 else i
|
||||
expert_layer = models[1].layers[expert_layer_index]
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
continue
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
att_out = att_output[:, start:end]
|
||||
out_emb = layer.self_attn.o_proj(att_out)
|
||||
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end if len(att_outputs) == 1 else 0
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.num_attention_heads
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -109,10 +109,6 @@ def predict_action(observation, policy, device, use_amp):
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
# Skip all observations that are not tensors (e.g. text)
|
||||
if not isinstance(observation[name], torch.Tensor):
|
||||
continue
|
||||
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
@@ -247,11 +243,6 @@ def control_loop(
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
# Controls starts, if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
@@ -259,9 +250,7 @@ def control_loop(
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
action = None
|
||||
observation["task"] = [single_task]
|
||||
observation["robot_type"] = [policy.robot_type] if hasattr(policy, "robot_type") else [""]
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
@@ -277,10 +266,9 @@ def control_loop(
|
||||
|
||||
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon)
|
||||
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")):
|
||||
if action is not None:
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
for k, v in action.items():
|
||||
for i, vv in enumerate(v):
|
||||
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy()))
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
|
||||
@@ -431,69 +431,6 @@ class MossRobotConfig(ManipulatorRobotConfig):
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so101")
|
||||
@dataclass
|
||||
class So101RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so101"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
|
||||
@@ -36,12 +36,6 @@ ZERO_POSITION_DEGREE = 0
|
||||
ROTATED_POSITION_DEGREE = 90
|
||||
|
||||
|
||||
def reset_middle_positions(arm: MotorsBus):
|
||||
input("Please move the robot to the new middle position for calibration, then press Enter...")
|
||||
# Write 128 to Torque_Enable for all motors.
|
||||
arm.write("Torque_Enable", 128)
|
||||
|
||||
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
@@ -445,8 +439,6 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
reset_middle_positions(arm)
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
@@ -243,7 +243,7 @@ class ManipulatorRobot:
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "so101", "moss", "lekiwi"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
|
||||
# We assume that at connection time, arms are in a rest position, and torque can
|
||||
@@ -260,7 +260,7 @@ class ManipulatorRobot:
|
||||
self.set_koch_robot_preset()
|
||||
elif self.robot_type == "aloha":
|
||||
self.set_aloha_robot_preset()
|
||||
elif self.robot_type in ["so100", "so101", "moss", "lekiwi"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
self.set_so100_robot_preset()
|
||||
|
||||
# Enable torque on all motors of the follower arms
|
||||
@@ -313,7 +313,7 @@ class ManipulatorRobot:
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "so101", "moss", "lekiwi"]:
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ from lerobot.common.robot_devices.robots.configs import (
|
||||
MossRobotConfig,
|
||||
RobotConfig,
|
||||
So100RobotConfig,
|
||||
So101RobotConfig,
|
||||
StretchRobotConfig,
|
||||
)
|
||||
|
||||
@@ -59,8 +58,6 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
return MossRobotConfig(**kwargs)
|
||||
elif robot_type == "so100":
|
||||
return So100RobotConfig(**kwargs)
|
||||
elif robot_type == "so101":
|
||||
return So101RobotConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
return StretchRobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
|
||||
@@ -228,3 +228,13 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
|
||||
def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
|
||||
days = int(elapsed_time_s // (24 * 3600))
|
||||
elapsed_time_s %= 24 * 3600
|
||||
hours = int(elapsed_time_s // 3600)
|
||||
elapsed_time_s %= 3600
|
||||
minutes = int(elapsed_time_s // 60)
|
||||
seconds = elapsed_time_s % 60
|
||||
return days, hours, minutes, seconds
|
||||
|
||||
@@ -94,8 +94,8 @@ def rollout(
|
||||
data will probably need to be discarded (for environments that aren't the first one to be done).
|
||||
|
||||
The return dictionary contains:
|
||||
(optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE that this has an extra sequence element relative to the other keys in the
|
||||
(optional) "observation": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
|
||||
@@ -166,7 +166,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
@@ -79,8 +79,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
@@ -283,7 +283,7 @@ def main():
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -271,8 +271,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
selected_columns.insert(0, "timestamp")
|
||||
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
data = (
|
||||
dataset.hf_dataset.select(range(from_idx, to_idx))
|
||||
.select_columns(selected_columns)
|
||||
@@ -308,7 +308,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index]
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.meta.video_keys
|
||||
@@ -321,7 +321,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
||||
return None
|
||||
|
||||
# get first frame index
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index]
|
||||
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
|
||||
|
Before Width: | Height: | Size: 64 KiB |
|
Before Width: | Height: | Size: 41 KiB |
|
Before Width: | Height: | Size: 45 KiB |
|
Before Width: | Height: | Size: 64 KiB |
|
Before Width: | Height: | Size: 35 KiB |
|
Before Width: | Height: | Size: 39 KiB |
|
Before Width: | Height: | Size: 38 KiB |
|
Before Width: | Height: | Size: 30 KiB |
|
Before Width: | Height: | Size: 151 KiB |
|
Before Width: | Height: | Size: 130 KiB |
@@ -49,7 +49,7 @@ dependencies = [
|
||||
"datasets>=2.19.0",
|
||||
"deepdiff>=7.0.1",
|
||||
"diffusers>=0.27.2",
|
||||
"draccus==0.10.0",
|
||||
"draccus>=0.10.0",
|
||||
"einops>=0.8.0",
|
||||
"flask>=3.0.3",
|
||||
"gdown>=5.1.0",
|
||||
@@ -62,8 +62,8 @@ dependencies = [
|
||||
"omegaconf>=2.3.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"packaging>=24.2",
|
||||
"av>=14.2.0",
|
||||
"pymunk>=6.6.0,<7.0.0",
|
||||
"av>=12.0.5",
|
||||
"pymunk>=6.6.0",
|
||||
"pynput>=1.7.7",
|
||||
"pyzmq>=26.2.1",
|
||||
"rerun-sdk>=0.21.0",
|
||||
@@ -77,7 +77,6 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"]
|
||||
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"]
|
||||
dora = [
|
||||
"gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'",
|
||||
@@ -86,7 +85,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
||||
@@ -47,17 +47,23 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
- dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
# save 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
|
||||
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
|
||||
|
||||
@@ -65,17 +71,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
|
||||
# # save 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# i = dataset.meta.episodes["dataset_from_index"][1].item()
|
||||
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# # save 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# i = dataset.meta.episodes["dataset_to_index"][1].item()
|
||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
# # save 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
|
||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6b1e600768a8771c5fe650e038a1193597e3810f032041b2a0d021e4496381c1
|
||||
oid sha256:0389a716d51c1c615fb2a3bfa386d89f00b0deca08c4fa21b23e020a939d0213
|
||||
size 3686488
|
||||
|
||||
@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9d4ebab73eabddc58879a4e770289d19e00a1a4cf2fa5fa33cd3a3246992bc90
|
||||
oid sha256:0dc691503e7d90b2086bb408e89a65f772ce5ee6e3562ef8c127bcb09bd90851
|
||||
size 40551392
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77
|
||||
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
|
||||
size 5104
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1a7a8b1a457149109f843c32bcbb047d09de2201847b9b79f7501b447f77ecf4
|
||||
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5e6ce85296b2009e7c2060d336c0429b1c7197d9adb159e7df0ba18003067b36
|
||||
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603
|
||||
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
|
||||
size 33400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b
|
||||
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
|
||||
size 515400
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:224b5fa4828aa88171b68c036e8919c1eae563e2113f03b6461eadf5bf8525a6
|
||||
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
|
||||
size 31672
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:016d2fa8fe5f58017dfd46f4632fdc19dfd751e32a2c7cde2077c6f95546d6bd
|
||||
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
|
||||
size 68
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075
|
||||
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
|
||||
size 33400
|
||||
|
||||
29
tests/datasets/test_aggregate.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from lerobot.common.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_0",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "test_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_1",
|
||||
total_episodes=10,
|
||||
total_frames=400,
|
||||
)
|
||||
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
|
||||
aggr_root=tmp_path / "test_aggr",
|
||||
)
|
||||
|
||||
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")
|
||||
for _ in aggr_ds:
|
||||
pass
|
||||
@@ -13,10 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
@@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
@@ -75,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
@@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
|
||||
# and test the small resulting function that validates the features
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
@@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
# TODO(rcadene):
|
||||
# - [ ] fix code so that old test_factory + backward pass
|
||||
# - [ ] write new unit tests to test save_episode + getitem
|
||||
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
|
||||
# - [ ]
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
@@ -436,30 +458,6 @@ def test_multidataset_frames():
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -506,17 +504,23 @@ def test_backward_compatibility(repo_id):
|
||||
)
|
||||
|
||||
# test2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
- dataset.meta.episodes["dataset_from_index"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
|
||||
@@ -524,17 +528,17 @@ def test_backward_compatibility(repo_id):
|
||||
# We currently cant because our test dataset only contains the first episode
|
||||
|
||||
# # test 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# i = dataset.meta.episodes["dataset_from_index"][1].item()
|
||||
# load_and_compare(i)
|
||||
# load_and_compare(i + 1)
|
||||
|
||||
# # test 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# i = dataset.meta.episodes["dataset_to_index"][1].item()
|
||||
# load_and_compare(i - 2)
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
# # test 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
|
||||
# load_and_compare(i - 2)
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
@@ -563,20 +567,3 @@ def test_create_branch():
|
||||
|
||||
# Clean
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
@@ -254,14 +253,7 @@ def test_backward_compatibility_single_transforms(
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
@pytest.mark.skipif(
|
||||
version.parse(torch.__version__) < version.parse("2.7.0"),
|
||||
reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior",
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
||||
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
|
||||
assert sampler.indices == [1, 4, 5]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [1, 4, 5]
|
||||
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
|
||||
assert sampler.indices == [0, 3, 4]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [0, 3, 4]
|
||||
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
||||
sampler = EpisodeAwareSampler(
|
||||
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
|
||||
)
|
||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||
assert len(sampler) == 5
|
||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||
@@ -80,11 +82,11 @@ def test_shuffle():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
@@ -14,12 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
4
tests/fixtures/constants.py
vendored
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
|
||||
249
tests/fixtures/dataset_factories.py
vendored
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import shutil
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
@@ -19,19 +20,25 @@ from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
flatten_dict,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
@@ -46,10 +53,10 @@ class LeRobotDatasetFactory(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||
|
||||
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
def get_task_index(tasks: datasets.Dataset, task: str) -> int:
|
||||
# TODO(rcadene): a bit complicated no? ^^
|
||||
task_idx = tasks.loc[task].task_index.item()
|
||||
return task_idx
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -90,7 +97,7 @@ def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
use_videos: bool = False,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
@@ -117,13 +124,14 @@ def info_factory(features_factory):
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path: str = DEFAULT_DATA_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
use_videos: bool = False,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
@@ -133,8 +141,9 @@ def info_factory(features_factory):
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": chunks_size,
|
||||
"data_files_size_in_mb": data_files_size_in_mb,
|
||||
"video_files_size_in_mb": video_files_size_in_mb,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
@@ -175,41 +184,45 @@ def stats_factory():
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_factory(stats_factory):
|
||||
def _create_episodes_stats(
|
||||
features: dict[str],
|
||||
total_episodes: int = 3,
|
||||
) -> dict:
|
||||
episodes_stats = {}
|
||||
for episode_index in range(total_episodes):
|
||||
episodes_stats[episode_index] = {
|
||||
"episode_index": episode_index,
|
||||
"stats": stats_factory(features),
|
||||
}
|
||||
return episodes_stats
|
||||
# @pytest.fixture(scope="session")
|
||||
# def episodes_stats_factory(stats_factory):
|
||||
# def _create_episodes_stats(
|
||||
# features: dict[str],
|
||||
# total_episodes: int = 3,
|
||||
# ) -> dict:
|
||||
|
||||
return _create_episodes_stats
|
||||
# def _generator(total_episodes):
|
||||
# for ep_idx in range(total_episodes):
|
||||
# flat_ep_stats = flatten_dict(stats_factory(features))
|
||||
# flat_ep_stats["episode_index"] = ep_idx
|
||||
# yield flat_ep_stats
|
||||
|
||||
# # Simpler to rely on generator instead of from_dict
|
||||
# return Dataset.from_generator(lambda: _generator(total_episodes))
|
||||
|
||||
# return _create_episodes_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks = {}
|
||||
for task_index in range(total_tasks):
|
||||
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
|
||||
tasks[task_index] = task_dict
|
||||
return tasks
|
||||
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
|
||||
ids = list(range(total_tasks))
|
||||
tasks = [f"Perform action {i}." for i in ids]
|
||||
df = pd.DataFrame({"task_index": ids}, index=tasks)
|
||||
return df
|
||||
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
def episodes_factory(tasks_factory, stats_factory):
|
||||
def _create_episodes(
|
||||
features: dict[str],
|
||||
fps: int = DEFAULT_FPS,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
tasks: dict | None = None,
|
||||
video_keys: list[str] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
@@ -217,66 +230,139 @@ def episodes_factory(tasks_factory):
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
num_tasks_available = len(tasks)
|
||||
|
||||
if total_episodes < num_tasks_available and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
|
||||
num_tasks_available = len(tasks_list)
|
||||
# Create empty dictionaries with all keys
|
||||
d = {
|
||||
"episode_index": [],
|
||||
"meta/episodes/chunk_index": [],
|
||||
"meta/episodes/file_index": [],
|
||||
"data/chunk_index": [],
|
||||
"data/file_index": [],
|
||||
"dataset_from_index": [],
|
||||
"dataset_to_index": [],
|
||||
"tasks": [],
|
||||
"length": [],
|
||||
}
|
||||
if video_keys is not None:
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"] = []
|
||||
d[f"videos/{video_key}/file_index"] = []
|
||||
d[f"videos/{video_key}/from_timestamp"] = []
|
||||
d[f"videos/{video_key}/to_timestamp"] = []
|
||||
|
||||
episodes = {}
|
||||
remaining_tasks = tasks_list.copy()
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
|
||||
num_frames = 0
|
||||
remaining_tasks = list(tasks.index)
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
||||
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
|
||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||
if remaining_tasks:
|
||||
for task in episode_tasks:
|
||||
remaining_tasks.remove(task)
|
||||
|
||||
episodes[ep_idx] = {
|
||||
"episode_index": ep_idx,
|
||||
"tasks": episode_tasks,
|
||||
"length": lengths[ep_idx],
|
||||
}
|
||||
d["episode_index"].append(ep_idx)
|
||||
# TODO(rcadene): remove heuristic of only one file
|
||||
d["meta/episodes/chunk_index"].append(0)
|
||||
d["meta/episodes/file_index"].append(0)
|
||||
d["data/chunk_index"].append(0)
|
||||
d["data/file_index"].append(0)
|
||||
d["dataset_from_index"].append(num_frames)
|
||||
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
|
||||
d["tasks"].append(episode_tasks)
|
||||
d["length"].append(lengths[ep_idx])
|
||||
|
||||
return episodes
|
||||
if video_keys is not None:
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||
d[f"videos/{video_key}/file_index"].append(0)
|
||||
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||
|
||||
# Add stats columns like "stats/action/max"
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
d[stats_key].append(stats)
|
||||
|
||||
num_frames += lengths[ep_idx]
|
||||
|
||||
return Dataset.from_dict(d)
|
||||
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_videos(info_factory, img_array_factory):
|
||||
def _create_video_directory(
|
||||
root: Path,
|
||||
info: dict | None = None,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
):
|
||||
if info is None:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
|
||||
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
||||
for key, ft in video_feats.items():
|
||||
# create and save images
|
||||
tmp_dir = root / "tmp_images"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
for frame_index in range(info["total_frames"]):
|
||||
img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
||||
pil_img = PIL.Image.fromarray(img)
|
||||
path = tmp_dir / f"frame-{frame_index:06d}.png"
|
||||
pil_img.save(path)
|
||||
|
||||
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
||||
encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
return _create_video_directory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features, fps)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes.values():
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
|
||||
# TODO(rcadene): assign the tasks of the episode per chunks of frames
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
@@ -286,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||
img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
@@ -314,7 +400,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
@@ -324,29 +409,29 @@ def lerobot_dataset_metadata_factory(
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
@@ -368,7 +453,6 @@ def lerobot_dataset_metadata_factory(
|
||||
def lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
@@ -384,38 +468,38 @@ def lerobot_dataset_factory(
|
||||
multi_task: bool = False,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes_metadata: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
if episodes_metadata is None:
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
|
||||
|
||||
# Write data on disk
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
episodes=episodes_metadata,
|
||||
hf_dataset=hf_dataset,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
@@ -423,9 +507,8 @@ def lerobot_dataset_factory(
|
||||
repo_id=repo_id,
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
episodes=episodes_metadata,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
|
||||
100
tests/fixtures/files.py
vendored
@@ -11,92 +11,72 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pandas as pd
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
write_episodes,
|
||||
write_hf_dataset,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_path(info_factory):
|
||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
||||
if not info:
|
||||
def create_info(info_factory):
|
||||
def _create_info(dir: Path, info: dict | None = None):
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
fpath = dir / INFO_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
write_info(info, dir)
|
||||
|
||||
return _create_info_json_file
|
||||
return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_path(stats_factory):
|
||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
||||
if not stats:
|
||||
def create_stats(stats_factory):
|
||||
def _create_stats(dir: Path, stats: dict | None = None):
|
||||
if stats is None:
|
||||
stats = stats_factory()
|
||||
fpath = dir / STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
write_stats(stats, dir)
|
||||
|
||||
return _create_stats_json_file
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_path(episodes_stats_factory):
|
||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory()
|
||||
fpath = dir / EPISODES_STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes_stats.values())
|
||||
return fpath
|
||||
|
||||
return _create_episodes_stats_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks_factory):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
||||
if not tasks:
|
||||
def create_tasks(tasks_factory):
|
||||
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
|
||||
if tasks is None:
|
||||
tasks = tasks_factory()
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks.values())
|
||||
return fpath
|
||||
write_tasks(tasks, dir)
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_path(episodes_factory):
|
||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
||||
if not episodes:
|
||||
def create_episodes(episodes_factory):
|
||||
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
|
||||
if episodes is None:
|
||||
# TODO(rcadene): add features, fps as arguments
|
||||
episodes = episodes_factory()
|
||||
fpath = dir / EPISODES_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes.values())
|
||||
return fpath
|
||||
write_episodes(episodes, dir)
|
||||
|
||||
return _create_episodes_jsonl_file
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def create_hf_dataset(hf_dataset_factory):
|
||||
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
write_hf_dataset(hf_dataset, dir)
|
||||
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -104,7 +84,8 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
raise NotImplementedError()
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
@@ -127,7 +108,8 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_multi_episode_parquet(
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
raise NotImplementedError()
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
128
tests/fixtures/hub.py
vendored
@@ -14,15 +14,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
)
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
@@ -30,17 +32,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info_factory,
|
||||
info_path,
|
||||
create_info,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
episodes_stats_factory,
|
||||
episodes_stats_path,
|
||||
create_stats,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
create_tasks,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
create_episodes,
|
||||
hf_dataset_factory,
|
||||
create_hf_dataset,
|
||||
create_videos,
|
||||
):
|
||||
"""
|
||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||
@@ -50,82 +51,91 @@ def mock_snapshot_download_factory(
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
):
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
if not hf_dataset:
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
if local_dir is None:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
all_files.extend(meta_files)
|
||||
all_files = [
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||
]
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes.values():
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
|
||||
for key in video_keys:
|
||||
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
request_info = False
|
||||
request_tasks = False
|
||||
request_episodes = False
|
||||
request_stats = False
|
||||
request_data = False
|
||||
request_videos = False
|
||||
for rel_path in allowed_files:
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == EPISODES_STATS_PATH:
|
||||
_ = episodes_stats_path(local_dir, episodes_stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes)
|
||||
if rel_path.startswith("meta/info.json"):
|
||||
request_info = True
|
||||
elif rel_path.startswith("meta/stats"):
|
||||
request_stats = True
|
||||
elif rel_path.startswith("meta/tasks"):
|
||||
request_tasks = True
|
||||
elif rel_path.startswith("meta/episodes"):
|
||||
request_episodes = True
|
||||
elif rel_path.startswith("data/"):
|
||||
request_data = True
|
||||
elif rel_path.startswith("videos/"):
|
||||
request_videos = True
|
||||
else:
|
||||
pass
|
||||
raise ValueError(f"{rel_path} not supported.")
|
||||
|
||||
if request_info:
|
||||
create_info(local_dir, info)
|
||||
if request_stats:
|
||||
create_stats(local_dir, stats)
|
||||
if request_tasks:
|
||||
create_tasks(local_dir, tasks)
|
||||
if request_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
if request_data:
|
||||
create_hf_dataset(local_dir, hf_dataset)
|
||||
if request_videos:
|
||||
create_videos(root=local_dir, info=info)
|
||||
|
||||
return str(local_dir)
|
||||
|
||||
return _mock_snapshot_download
|
||||
|
||||