Compare commits
2 Commits
Cadene-pat
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bbf617fd92 | ||
|
|
73f46bac56 |
52
.github/workflows/build-docker-images.yml
vendored
52
.github/workflows/build-docker-images.yml
vendored
@@ -14,14 +14,20 @@ env:
|
||||
jobs:
|
||||
latest-cpu:
|
||||
name: CPU
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -49,15 +55,20 @@ jobs:
|
||||
|
||||
latest-cuda:
|
||||
name: GPU
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -84,9 +95,20 @@ jobs:
|
||||
|
||||
latest-cuda-dev:
|
||||
name: GPU Dev
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
6
.github/workflows/nightly-tests.yml
vendored
6
.github/workflows/nightly-tests.yml
vendored
@@ -16,8 +16,7 @@ jobs:
|
||||
name: CPU
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: huggingface/lerobot-cpu:latest
|
||||
options: --shm-size "16gb"
|
||||
@@ -44,8 +43,7 @@ jobs:
|
||||
name: GPU
|
||||
strategy:
|
||||
fail-fast: false
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu"
|
||||
|
||||
28
.github/workflows/quality.yml
vendored
28
.github/workflows/quality.yml
vendored
@@ -54,31 +54,3 @@ jobs:
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
|
||||
|
||||
poetry_relax:
|
||||
name: Poetry relax
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
- name: Poetry relax
|
||||
id: poetry_relax
|
||||
run: |
|
||||
output=$(poetry relax --check 2>&1)
|
||||
if echo "$output" | grep -q "Proposing updates"; then
|
||||
echo "$output"
|
||||
echo ""
|
||||
echo "Some dependencies have caret '^' version requirement added by poetry by default."
|
||||
echo "Please replace them with '>='. You can do this by hand or use poetry-relax to do this."
|
||||
exit 1
|
||||
else
|
||||
echo "$output"
|
||||
fi
|
||||
|
||||
16
.github/workflows/test-docker-build.yml
vendored
16
.github/workflows/test-docker-build.yml
vendored
@@ -42,14 +42,26 @@ jobs:
|
||||
build_modified_dockerfiles:
|
||||
name: Build modified Docker images
|
||||
needs: get_changed_files
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ needs.get_changed_files.outputs.matrix }} != ''
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo df -h
|
||||
# sudo ls -l /usr/local/lib/
|
||||
# sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
11
.github/workflows/test.yml
vendored
11
.github/workflows/test.yml
vendored
@@ -10,7 +10,6 @@ on:
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
@@ -20,7 +19,6 @@ on:
|
||||
- "examples/**"
|
||||
- ".github/**"
|
||||
- "poetry.lock"
|
||||
- "Makefile"
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
@@ -34,8 +32,8 @@ jobs:
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev ffmpeg
|
||||
- name: Install EGL
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
@@ -72,9 +70,6 @@ jobs:
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ffmpeg
|
||||
|
||||
- name: Install poetry
|
||||
run: |
|
||||
pipx install poetry && poetry config virtualenvs.in-project true
|
||||
@@ -109,7 +104,7 @@ jobs:
|
||||
with:
|
||||
lfs: true # Ensure LFS files are pulled
|
||||
|
||||
- name: Install apt dependencies
|
||||
- name: Install EGL
|
||||
run: sudo apt-get update && sudo apt-get install -y libegl1-mesa-dev
|
||||
|
||||
- name: Install poetry
|
||||
|
||||
20
.github/workflows/trufflehog.yml
vendored
20
.github/workflows/trufflehog.yml
vendored
@@ -1,20 +0,0 @@
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
with:
|
||||
extra_args: --only-verified
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -121,8 +121,8 @@ celerybeat.pid
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ repos:
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.16.0
|
||||
rev: v3.15.2
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.5.2
|
||||
rev: v0.4.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -31,7 +31,3 @@ repos:
|
||||
args:
|
||||
- "--check"
|
||||
- "--no-update"
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.18.4
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
31
Makefile
31
Makefile
@@ -5,7 +5,7 @@ PYTHON_PATH := $(shell which python)
|
||||
# If Poetry is installed, redefine PYTHON_PATH to use the Poetry-managed Python
|
||||
POETRY_CHECK := $(shell command -v poetry)
|
||||
ifneq ($(POETRY_CHECK),)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
PYTHON_PATH := $(shell poetry run which python)
|
||||
endif
|
||||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
@@ -26,7 +26,6 @@ test-end-to-end:
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
@@ -47,7 +46,6 @@ test-act-ete-train:
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
|
||||
test-act-ete-eval:
|
||||
@@ -75,7 +73,6 @@ test-act-ete-train-amp:
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act_amp/ \
|
||||
training.image_transforms.enable=true \
|
||||
use_amp=true
|
||||
|
||||
test-act-ete-eval-amp:
|
||||
@@ -103,7 +100,6 @@ test-diffusion-ete-train:
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/diffusion/
|
||||
|
||||
test-diffusion-ete-eval:
|
||||
@@ -114,6 +110,7 @@ test-diffusion-ete-eval:
|
||||
env.episode_length=8 \
|
||||
device=$(DEVICE) \
|
||||
|
||||
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||
test-tdmpc-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
policy=tdmpc \
|
||||
@@ -130,31 +127,8 @@ test-tdmpc-ete-train:
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc/
|
||||
|
||||
test-tdmpc-ete-train-with-online:
|
||||
python lerobot/scripts/train.py \
|
||||
env=pusht \
|
||||
env.gym.obs_type=environment_state_agent_pos \
|
||||
policy=tdmpc_pusht_keypoints \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=10 \
|
||||
device=$(DEVICE) \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=20 \
|
||||
training.save_checkpoint=false \
|
||||
training.save_freq=10 \
|
||||
training.batch_size=2 \
|
||||
training.online_rollout_n_episodes=2 \
|
||||
training.online_rollout_batch_size=2 \
|
||||
training.online_steps_between_rollouts=10 \
|
||||
training.online_buffer_capacity=15 \
|
||||
eval.use_async_envs=true \
|
||||
hydra.run.dir=tests/outputs/tdmpc_online/
|
||||
|
||||
|
||||
test-tdmpc-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
-p tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
|
||||
@@ -185,6 +159,5 @@ test-act-pusht-tutorial:
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
training.image_transforms.enable=true \
|
||||
hydra.run.dir=tests/outputs/act_pusht/
|
||||
rm lerobot/configs/policy/created_by_Makefile.yaml
|
||||
|
||||
128
README.md
128
README.md
@@ -22,21 +22,8 @@
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art AI for real-world robotics</p>
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
@@ -71,26 +58,23 @@
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
Download our source code:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git && cd lerobot
|
||||
```
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install -e .
|
||||
pip install .
|
||||
```
|
||||
|
||||
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
||||
@@ -104,7 +88,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install -e ".[aloha, pusht]"
|
||||
pip install ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -129,12 +113,10 @@ wandb login
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
| | ├── policies # various policies: act, diffusion, tdmpc
|
||||
| | ├── robot_devices # various real devices: dynamixel motors, opencv cameras, koch robots
|
||||
| | └── utils # various utilities
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| ├── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
| ├── control_robot.py # teleoperate a real robot, record data, run a policy
|
||||
| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
|
||||
| └── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
@@ -152,7 +134,7 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
|
||||
or from a dataset in a local folder with the root `DATA_DIR` environment variable
|
||||
```bash
|
||||
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id lerobot/pusht \
|
||||
@@ -169,50 +151,48 @@ Our script can also visualize datasets stored on a distant server. See `python l
|
||||
|
||||
### The `LeRobotDataset` format
|
||||
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and Pytorch dataset. For instance `dataset[0]` will retrieve a sample of the dataset observations and actions in pytorch tensors format ready to be fed to a model.
|
||||
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
A specificity of `LeRobotDataset` is that we can retrieve several frames for one sample query. By setting `delta_timestamps` to a list of delta timestamps, e.g. `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for each query, 4 images including one at -1 second before the current time step, the two others at -0.5 second and -0.2, and the final one at the current time step (0 second). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states.
|
||||
|
||||
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
|
||||
|
||||
```
|
||||
dataset attributes:
|
||||
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
|
||||
│ ├ observation.images.cam_high (VideoFrame):
|
||||
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp' (float32): timestamp in the video}
|
||||
│ ├ observation.state (list of float32): position of an arm joints (for instance)
|
||||
│ ├ observation.images.cam_high: VideoFrame
|
||||
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp': float32 timestamp in the video}
|
||||
│ ├ observation.state: List of float32: position of an arm joints (for instance)
|
||||
│ ... (more observations)
|
||||
│ ├ action (list of float32): goal position of an arm joints (for instance)
|
||||
│ ├ episode_index (int64): index of the episode for this sample
|
||||
│ ├ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp (float32): timestamp in the episode
|
||||
│ ├ next.done (bool): indicates the end of en episode ; True for the last frame in each episode
|
||||
│ └ index (int64): general index in the whole dataset
|
||||
│ ├ action: List of float32
|
||||
│ ├ episode_index: int64: index of the episode for this sample
|
||||
│ ├ frame_index: int64: index of the frame for this sample in the episode ; starts at 0 for each episode
|
||||
│ ├ timestamp: float32: timestamp in the episode
|
||||
│ ├ next.done: bool: indicates the end of en episode ; True for the last frame in each episode
|
||||
│ └ index: int64: general index in the whole dataset
|
||||
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
|
||||
│ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0
|
||||
│ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,)
|
||||
│ ├ from: 1D int64 tensor of first frame index for each episode: shape (num episodes,) starts with 0
|
||||
│ └ to: 1D int64 tensor of last frame index for each episode: shape (num episodes,)
|
||||
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
|
||||
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
|
||||
│ ...
|
||||
├ info: a dictionary of metadata on the dataset
|
||||
│ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with
|
||||
│ ├ fps (float): frame per second the dataset is recorded/synchronized to
|
||||
│ ├ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files
|
||||
│ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos
|
||||
├ videos_dir (Path): where the mp4 videos or png images are stored/accessed
|
||||
└ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
|
||||
│ ├ fps: float - frame per second the dataset is recorded/synchronized to
|
||||
│ └ video: bool - indicates if frames are encoded in mp4 video files to save space or stored as png files
|
||||
├ videos_dir: path to where the mp4 videos or png images are stored/accessed
|
||||
└ camera_keys: List of string: the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
|
||||
```
|
||||
|
||||
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
|
||||
- hf_dataset stored using Hugging Face datasets library serialization to parquet
|
||||
- videos are stored in mp4 format to save space or png files
|
||||
- episode_data_index saved using `safetensor` tensor serialization format
|
||||
- stats saved using `safetensor` tensor serialization format
|
||||
- episode_data_index saved using `safetensor` tensor serializtion format
|
||||
- stats saved using `safetensor` tensor serializtion format
|
||||
- info are saved using JSON
|
||||
|
||||
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.
|
||||
Dataset can uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to you root dataset folder as illustrated in the above section on dataset visualization.
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
@@ -301,13 +281,13 @@ To add a dataset to the hub, you need to login using a write-access token, which
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
|
||||
Then move your dataset folder in `data` directory (e.g. `data/aloha_static_pingpong_test`), and push your dataset to the hub with:
|
||||
```bash
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/aloha_static_pingpong_test_raw \
|
||||
--out-dir data \
|
||||
--repo-id lerobot/aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5
|
||||
--data-dir data \
|
||||
--dataset-id aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
@@ -359,7 +339,7 @@ with profile(
|
||||
## Citation
|
||||
|
||||
If you want, you can cite this work with:
|
||||
```bibtex
|
||||
```
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
@@ -367,45 +347,3 @@ If you want, you can cite this work with:
|
||||
year = {2024}
|
||||
}
|
||||
```
|
||||
|
||||
Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below:
|
||||
|
||||
- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu)
|
||||
```bibtex
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}
|
||||
```
|
||||
- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha)
|
||||
```bibtex
|
||||
@article{zhao2023learning,
|
||||
title={Learning fine-grained bimanual manipulation with low-cost hardware},
|
||||
author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
|
||||
journal={arXiv preprint arXiv:2304.13705},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
- [TDMPC](https://www.nicklashansen.com/td-mpc/)
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
- [VQ-BeT](https://sjlee.cc/vq-bet/)
|
||||
```bibtex
|
||||
@article{lee2024behavior,
|
||||
title={Behavior generation with latent actions},
|
||||
author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal={arXiv preprint arXiv:2403.03181},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
# Video benchmark
|
||||
|
||||
|
||||
## Questions
|
||||
What is the optimal trade-off between:
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies,
|
||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
||||
|
||||
How to encode videos?
|
||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
||||
|
||||
|
||||
## Variables
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
||||
- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
||||
|
||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
### Encoding parameters
|
||||
| parameter | values |
|
||||
|-------------|--------------------------------------------------------------|
|
||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
||||
|
||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
||||
|
||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
||||
|
||||
### Decoding parameters
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
- `pyav` (default)
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
||||
- `1_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
|
||||
|
||||
## Metrics
|
||||
**Data compression ratio (lower is better)**
|
||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Loading time ratio (lower is better)**
|
||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average Mean Square Error (lower is better)**
|
||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
||||
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
|
||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
||||
|
||||
|
||||
## How the benchmark works
|
||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
||||
|
||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
||||
|
||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
||||
|
||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
||||
These are then all concatenated to a single table ready for analysis.
|
||||
|
||||
## Caveats
|
||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
||||
|
||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
- `torchaudio`
|
||||
- `ffmpegio`
|
||||
- `decord`
|
||||
- `nvc`
|
||||
|
||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
||||
|
||||
|
||||
## Install
|
||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
||||
|
||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
||||
|
||||
|
||||
## Adding a video decoder
|
||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
||||
```diff
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(
|
||||
video_path, timestamps, tolerance_s, backend
|
||||
)
|
||||
+ elif backend == ["your_decoder"]:
|
||||
+ return your_decoder_function(
|
||||
+ video_path, timestamps, tolerance_s, backend
|
||||
+ )
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
```
|
||||
|
||||
|
||||
## Example
|
||||
For a quick run, you can try these parameters:
|
||||
```bash
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 2 20 None \
|
||||
--crf 10 40 None \
|
||||
--timestamps-modes 1_frame 2_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 5 \
|
||||
--num-workers 5 \
|
||||
--save-frames 0
|
||||
```
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
### Reproduce
|
||||
We ran the benchmark with the following parameters:
|
||||
```bash
|
||||
# h264 and h265 encodings
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libx264 libx265 \
|
||||
--pix-fmt yuv444p yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav video_reader \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
|
||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
||||
python benchmark/video/run_video_benchmark.py \
|
||||
--output-dir outputs/video_benchmark \
|
||||
--repo-ids \
|
||||
lerobot/pusht_image \
|
||||
aliberts/aloha_mobile_shrimp_image \
|
||||
aliberts/paris_street \
|
||||
aliberts/kitchen \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
||||
--backends pyav \
|
||||
--num-samples 50 \
|
||||
--num-workers 5 \
|
||||
--save-frames 1
|
||||
```
|
||||
|
||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
||||
|
||||
|
||||
### Parameters selected for LeRobotDataset
|
||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
||||
- vcodec: `libsvtav1`
|
||||
- pix-fmt: `yuv420p`
|
||||
- g: `2`
|
||||
- crf: `30`
|
||||
|
||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
||||
|
||||
### Summary
|
||||
|
||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
||||
|
||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|------------|---------|-----------|-----------|-----------|
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
||||
| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
||||
| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
||||
| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
||||
|
||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|---------|---------|----------|---------|-----------|
|
||||
| | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
||||
| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
||||
| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
||||
| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
||||
|
||||
| | | vcodec | pix_fmt | | | |
|
||||
|------------------------------------|----------|----------|--------------|----------|-----------|--------------|
|
||||
| | | libx264 | | libx265 | | libsvtav1 |
|
||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
||||
| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
||||
| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
||||
| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
||||
@@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Capture video feed from a camera as raw images."""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int):
|
||||
now = dt.datetime.now()
|
||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||
if not capture_dir.exists():
|
||||
capture_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Opens the default webcam
|
||||
cap = cv2.VideoCapture(0)
|
||||
if not cap.isOpened():
|
||||
print("Error: Could not open video stream.")
|
||||
return
|
||||
|
||||
cap.set(cv2.CAP_PROP_FPS, fps)
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||
|
||||
frame_index = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if not ret:
|
||||
print("Error: Could not read frame.")
|
||||
break
|
||||
|
||||
cv2.imshow("Video Stream", frame)
|
||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||
frame_index += 1
|
||||
|
||||
# Break the loop on 'q' key press
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
# Release the capture and destroy all windows
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/cam_capture/"),
|
||||
help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Frames Per Second of the capture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=1280,
|
||||
help="Width of the captured images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=720,
|
||||
help="Height of the captured images.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
display_and_save_video_stream(**vars(args))
|
||||
@@ -1,490 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Assess the performance of video decoding in various configurations.
|
||||
|
||||
This script will benchmark different video encoding and decoding parameters.
|
||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.common.utils.benchmark import TimeBenchmark
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
("vcodec", "libx264"),
|
||||
("pix_fmt", "yuv444p"),
|
||||
("g", 2),
|
||||
("crf", None),
|
||||
# TODO(aliberts): Add fastdecode
|
||||
# ("fastdecode", 0),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
||||
def parse_int_or_none(value) -> int | None:
|
||||
if value.lower() == "none":
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError as e:
|
||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
||||
|
||||
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory: Path) -> int:
|
||||
total_size = 0
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return torch.stack(frames)
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# We only save images from the first camera
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
|
||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
||||
idx = random.randint(5, ep_num_images - 1)
|
||||
match timestamps_mode:
|
||||
case "1_frame":
|
||||
frame_indexes = [idx]
|
||||
case "2_frames":
|
||||
frame_indexes = [idx - 1, idx]
|
||||
case "2_frames_4_space":
|
||||
frame_indexes = [idx - 5, idx]
|
||||
case "6_frames":
|
||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
||||
case _:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
timestamps_mode: str,
|
||||
backend: str,
|
||||
ep_num_images: int,
|
||||
fps: int,
|
||||
num_samples: int = 50,
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int):
|
||||
time_benchmark = TimeBenchmark()
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
"psnr_values": [],
|
||||
"ssim_values": [],
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
||||
result["psnr_values"].append(
|
||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
||||
)
|
||||
result["ssim_values"].append(
|
||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
||||
)
|
||||
|
||||
if save_frames and sample == 0:
|
||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
||||
|
||||
return result
|
||||
|
||||
load_times_video_ms = []
|
||||
load_times_images_ms = []
|
||||
mse_values = []
|
||||
psnr_values = []
|
||||
ssim_values = []
|
||||
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
load_times_images_ms.append(result["load_time_images_ms"])
|
||||
psnr_values.extend(result["psnr_values"])
|
||||
ssim_values.extend(result["ssim_values"])
|
||||
mse_values.extend(result["mse_values"])
|
||||
|
||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
||||
|
||||
return {
|
||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
||||
"avg_mse": float(np.mean(mse_values)),
|
||||
"avg_psnr": float(np.mean(psnr_values)),
|
||||
"avg_ssim": float(np.mean(ssim_values)),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_encoding_decoding(
|
||||
dataset: LeRobotDataset,
|
||||
video_path: Path,
|
||||
imgs_dir: Path,
|
||||
encoding_cfg: dict,
|
||||
decoding_cfg: dict,
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
overwrite: bool = False,
|
||||
seed: int = 1337,
|
||||
) -> list[dict]:
|
||||
fps = dataset.fps
|
||||
|
||||
if overwrite or not video_path.is_file():
|
||||
tqdm.write(f"encoding {video_path}")
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
||||
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
imgs_dir,
|
||||
video_path,
|
||||
timestamps_mode,
|
||||
backend,
|
||||
ep_num_images,
|
||||
fps,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
benchmark_row.update(
|
||||
**{
|
||||
"repo_id": dataset.repo_id,
|
||||
"resolution": f"{width} x {height}",
|
||||
"num_pixels": num_pixels,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"images_size_bytes": images_size_bytes,
|
||||
"video_images_size_ratio": video_images_size_ratio,
|
||||
"timestamps_mode": timestamps_mode,
|
||||
"backend": backend,
|
||||
},
|
||||
**encoding_cfg,
|
||||
)
|
||||
benchmark_table.append(benchmark_row)
|
||||
|
||||
return benchmark_table
|
||||
|
||||
|
||||
def main(
|
||||
output_dir: Path,
|
||||
repo_ids: list[str],
|
||||
vcodec: list[str],
|
||||
pix_fmt: list[str],
|
||||
g: list[int],
|
||||
crf: list[int],
|
||||
# fastdecode: list[int],
|
||||
timestamps_modes: list[str],
|
||||
backends: list[str],
|
||||
num_samples: int,
|
||||
num_workers: int,
|
||||
save_frames: bool,
|
||||
):
|
||||
check_datasets_formats(repo_ids)
|
||||
encoding_benchmarks = {
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
# "fastdecode": fastdecode,
|
||||
}
|
||||
decoding_benchmarks = {
|
||||
"timestamps_modes": timestamps_modes,
|
||||
"backends": backends,
|
||||
}
|
||||
headers = ["repo_id", "resolution", "num_pixels"]
|
||||
headers += list(BASE_ENCODING.keys())
|
||||
headers += [
|
||||
"timestamps_mode",
|
||||
"backend",
|
||||
"video_size_bytes",
|
||||
"images_size_bytes",
|
||||
"video_images_size_ratio",
|
||||
"avg_load_time_video_ms",
|
||||
"avg_load_time_images_ms",
|
||||
"video_images_load_time_ratio",
|
||||
"avg_mse",
|
||||
"avg_psnr",
|
||||
"avg_ssim",
|
||||
]
|
||||
file_paths = []
|
||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
||||
benchmark_table = []
|
||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
now = dt.datetime.now()
|
||||
csv_path = (
|
||||
output_dir
|
||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
||||
)
|
||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
||||
file_paths.append(csv_path)
|
||||
del benchmark_df
|
||||
|
||||
# Concatenate all results
|
||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/video_benchmark"),
|
||||
help="Directory where the video benchmark outputs are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"aliberts/aloha_mobile_shrimp_image",
|
||||
"aliberts/paris_street",
|
||||
"aliberts/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["libx264", "libx265", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["yuv444p", "yuv420p"],
|
||||
help="Pixel formats (chroma subsampling) to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
||||
help="Group of pictures sizes to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=parse_int_or_none,
|
||||
nargs="*",
|
||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
||||
help="Constant rate factors to be tested.",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fastdecode",
|
||||
# type=int,
|
||||
# nargs="*",
|
||||
# default=[0, 1],
|
||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
||||
# "For libx264 and libx265, only 1 is possible. "
|
||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--timestamps-modes",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
],
|
||||
help="Timestamps scenarios to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["pyav", "video_reader"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-samples",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of samples for each encoding x decoding config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of processes for parallelized sample processing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(**vars(args))
|
||||
@@ -8,8 +8,7 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
@@ -22,7 +21,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]" \
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04
|
||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
|
||||
|
||||
# Configure image
|
||||
ARG PYTHON_VERSION=3.10
|
||||
@@ -8,42 +8,14 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
git git-lfs openssh-client \
|
||||
nano vim less util-linux tree \
|
||||
nano vim less util-linux \
|
||||
htop atop nvtop \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
sed gawk grep curl wget \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install ffmpeg build dependencies. See:
|
||||
# https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu
|
||||
# TODO(aliberts): create image to build dependencies from source instead
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
autoconf automake yasm \
|
||||
libass-dev \
|
||||
libfreetype6-dev \
|
||||
libgnutls28-dev \
|
||||
libunistring-dev \
|
||||
libmp3lame-dev \
|
||||
libtool \
|
||||
libvorbis-dev \
|
||||
meson \
|
||||
ninja-build \
|
||||
pkg-config \
|
||||
texinfo \
|
||||
yasm \
|
||||
zlib1g-dev \
|
||||
nasm \
|
||||
libx264-dev \
|
||||
libx265-dev libnuma-dev \
|
||||
libvpx-dev \
|
||||
libfdk-aac-dev \
|
||||
libopus-dev \
|
||||
libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \
|
||||
libdav1d-dev
|
||||
|
||||
# Install gh cli tool
|
||||
RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||
&& mkdir -p -m 755 /etc/apt/keyrings \
|
||||
|
||||
@@ -8,9 +8,8 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
# Install apt dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -24,7 +23,7 @@ RUN echo "source /opt/venv/bin/activate" >> /root/.bashrc
|
||||
COPY . /lerobot
|
||||
WORKDIR /lerobot
|
||||
RUN pip install --upgrade --no-cache-dir pip
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht, koch]"
|
||||
RUN pip install --no-cache-dir ".[test, aloha, xarm, pusht]"
|
||||
|
||||
# Set EGL as the rendering backend for MuJoCo
|
||||
ENV MUJOCO_GL="egl"
|
||||
|
||||
@@ -18,6 +18,8 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
@@ -25,17 +27,6 @@ pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
|
||||
@@ -46,7 +46,7 @@ defaults:
|
||||
- policy: diffusion
|
||||
```
|
||||
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overridden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
This logic tells Hydra to incorporate configuration parameters from `env/pusht.yaml` and `policy/diffusion.yaml`. _Note: Be aware of the order as any configuration parameters with the same name will be overidden. Thus, `default.yaml` is overriden by `env/pusht.yaml` which is overidden by `policy/diffusion.yaml`_.
|
||||
|
||||
Then, `default.yaml` also contains common configuration parameters such as `device: cuda` or `use_amp: false` (for enabling fp16 training). Some other parameters are set to `???` which indicates that they are expected to be set in additional yaml files. For instance, `training.offline_steps: ???` in `default.yaml` is set to `200000` in `diffusion.yaml`.
|
||||
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from torchvision.transforms import ToPILImage, v2
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset_repo_id = "lerobot/aloha_static_tape"
|
||||
|
||||
# Create a LeRobotDataset with no transformations
|
||||
dataset = LeRobotDataset(dataset_repo_id)
|
||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 1.5)),
|
||||
v2.ColorJitter(contrast=(0.5, 1.5)),
|
||||
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
|
||||
]
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save the original frame
|
||||
to_pil = ToPILImage()
|
||||
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
|
||||
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
|
||||
|
||||
# Save the transformed frame
|
||||
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
|
||||
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -80,7 +80,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
89
examples/real_robot_example/README.md
Normal file
89
examples/real_robot_example/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Using `lerobot` on a real world arm
|
||||
|
||||
|
||||
In this example, we'll be using `lerobot` on a real world arm to:
|
||||
- record a dataset in the `lerobot` format
|
||||
- (soon) train a policy on it
|
||||
- (soon) run the policy in the real-world
|
||||
|
||||
## Which robotic arm to use
|
||||
|
||||
In this example we're using the [open-source low-cost arm from Alexander Koch](https://github.com/AlexanderKoch-Koch/low_cost_robot) in the specific setup of:
|
||||
- having 6 servos per arm, i.e. using the elbow-to-wrist extension
|
||||
- adding two cameras around it, one on top and one in the front
|
||||
- having a teleoperation arm as well (build the leader and the follower arms in A. Koch repo, both with elbow-to-wrist extensions)
|
||||
|
||||
I'm using these cameras (but the setup should not be sensitive to the exact cameras you're using):
|
||||
- C922 Pro Stream Webcam
|
||||
- Intel(R) RealSense D455 (using only the RGB input)
|
||||
|
||||
|
||||
In general, this example should be very easily extendable to any type of arm using Dynamixel servos with at least one camera by changing a couple of configuration in the gym env.
|
||||
|
||||
## Install the example
|
||||
|
||||
Follow these steps:
|
||||
- install `lerobot`
|
||||
- install the Dynamixel-sdk: `pip install dynamixel-sdk`
|
||||
|
||||
## Usage
|
||||
|
||||
### 0 - record examples
|
||||
|
||||
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
|
||||
```
|
||||
DATA_DIR='./data' python record_training_data.py \
|
||||
--repo-id=thomwolf/blue_red_sort \
|
||||
--num-episodes=50 \
|
||||
--num-frames=400
|
||||
```
|
||||
|
||||
TODO:
|
||||
- various length episodes
|
||||
- being able to drop episodes
|
||||
- checking uploading to the hub
|
||||
|
||||
### 1 - visualize the dataset
|
||||
|
||||
Use the standard dataset visualization script pointing it to the right folder:
|
||||
```
|
||||
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id thomwolf/blue_red_sort \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
### 2 - Train a policy
|
||||
|
||||
From the example directory let's run this command to train a model using ACT
|
||||
|
||||
```
|
||||
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
|
||||
device=cuda \
|
||||
hydra.searchpath=[file://./train_config/] \
|
||||
hydra.run.dir=./outputs/train/blue_red_sort \
|
||||
dataset_repo_id=thomwolf/blue_red_sort \
|
||||
env=gym_real_world \
|
||||
policy=act_real_world \
|
||||
wandb.enable=false
|
||||
```
|
||||
|
||||
### 3 - Evaluate the policy in the real world
|
||||
|
||||
From the example directory let's run this command to evaluate our policy.
|
||||
The configuration for running the policy is in the checkpoint of the model.
|
||||
You can override parameters as follow:
|
||||
|
||||
```
|
||||
python run_policy.py \
|
||||
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/
|
||||
env.episode_length=1000
|
||||
```
|
||||
|
||||
|
||||
## Convert a hdf5 dataset recorded with the original ACT repo
|
||||
|
||||
You can convert a dataset from the raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act with the following command:
|
||||
|
||||
```
|
||||
python ./lerobot/scripts/push_dataset_to_hub.py
|
||||
```
|
||||
@@ -0,0 +1,840 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from safetensors.torch import load_file, save_file\n",
|
||||
"from pprint import pprint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n",
|
||||
"converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n",
|
||||
"\n",
|
||||
"comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n",
|
||||
"comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n",
|
||||
"comparison_config_json_path = comparison_main_path + \"config.json\"\n",
|
||||
"comparison_config_yaml_path = comparison_main_path + \"config.yaml\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"a = torch.load(original_ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"b = load_file(comparison_safetensor_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['model.action_head.bias',\n",
|
||||
" 'model.action_head.weight',\n",
|
||||
" 'model.backbone.bn1.bias',\n",
|
||||
" 'model.backbone.bn1.running_mean',\n",
|
||||
" 'model.backbone.bn1.running_var',\n",
|
||||
" 'model.backbone.bn1.weight',\n",
|
||||
" 'model.backbone.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer1.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer1.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer1.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer1.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer1.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer1.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer1.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer1.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer1.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer1.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer1.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer1.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer1.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer1.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer1.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer1.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer1.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer2.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer2.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer2.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer2.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer2.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer2.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer2.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer2.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer2.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer2.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer2.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer2.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer2.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer2.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer2.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer2.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer2.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer2.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer2.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer3.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer3.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer3.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer3.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer3.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer3.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer3.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer3.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer3.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer3.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer3.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer3.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer3.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer3.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer3.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer3.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer3.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer3.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer3.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer4.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer4.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer4.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer4.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer4.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer4.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer4.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer4.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer4.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer4.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer4.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer4.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer4.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer4.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer4.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer4.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer4.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer4.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer4.1.conv2.weight',\n",
|
||||
" 'model.decoder.layers.0.linear1.bias',\n",
|
||||
" 'model.decoder.layers.0.linear1.weight',\n",
|
||||
" 'model.decoder.layers.0.linear2.bias',\n",
|
||||
" 'model.decoder.layers.0.linear2.weight',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.decoder.layers.0.norm1.bias',\n",
|
||||
" 'model.decoder.layers.0.norm1.weight',\n",
|
||||
" 'model.decoder.layers.0.norm2.bias',\n",
|
||||
" 'model.decoder.layers.0.norm2.weight',\n",
|
||||
" 'model.decoder.layers.0.norm3.bias',\n",
|
||||
" 'model.decoder.layers.0.norm3.weight',\n",
|
||||
" 'model.decoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.decoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.decoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.decoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.decoder_pos_embed.weight',\n",
|
||||
" 'model.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder_img_feat_input_proj.bias',\n",
|
||||
" 'model.encoder_img_feat_input_proj.weight',\n",
|
||||
" 'model.encoder_latent_input_proj.bias',\n",
|
||||
" 'model.encoder_latent_input_proj.weight',\n",
|
||||
" 'model.encoder_robot_and_latent_pos_embed.weight',\n",
|
||||
" 'model.encoder_robot_state_input_proj.bias',\n",
|
||||
" 'model.encoder_robot_state_input_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder_action_input_proj.bias',\n",
|
||||
" 'model.vae_encoder_action_input_proj.weight',\n",
|
||||
" 'model.vae_encoder_cls_embed.weight',\n",
|
||||
" 'model.vae_encoder_latent_output_proj.bias',\n",
|
||||
" 'model.vae_encoder_latent_output_proj.weight',\n",
|
||||
" 'model.vae_encoder_pos_enc',\n",
|
||||
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
|
||||
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_front.std',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_top.std',\n",
|
||||
" 'normalize_inputs.buffer_observation_state.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_state.std',\n",
|
||||
" 'normalize_targets.buffer_action.mean',\n",
|
||||
" 'normalize_targets.buffer_action.std',\n",
|
||||
" 'unnormalize_outputs.buffer_action.mean',\n",
|
||||
" 'unnormalize_outputs.buffer_action.std']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dest = list(b.keys())\n",
|
||||
"pprint(dest)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['model.pos_table',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.norm.weight',\n",
|
||||
" 'model.transformer.decoder.norm.bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.action_head.weight',\n",
|
||||
" 'model.action_head.bias',\n",
|
||||
" 'model.is_pad_head.weight',\n",
|
||||
" 'model.is_pad_head.bias',\n",
|
||||
" 'model.query_embed.weight',\n",
|
||||
" 'model.input_proj.weight',\n",
|
||||
" 'model.input_proj.bias',\n",
|
||||
" 'model.backbones.0.0.body.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.input_proj_robot_state.weight',\n",
|
||||
" 'model.input_proj_robot_state.bias',\n",
|
||||
" 'model.cls_embed.weight',\n",
|
||||
" 'model.encoder_action_proj.weight',\n",
|
||||
" 'model.encoder_action_proj.bias',\n",
|
||||
" 'model.encoder_joint_proj.weight',\n",
|
||||
" 'model.encoder_joint_proj.bias',\n",
|
||||
" 'model.latent_proj.weight',\n",
|
||||
" 'model.latent_proj.bias',\n",
|
||||
" 'model.latent_out_proj.weight',\n",
|
||||
" 'model.latent_out_proj.bias',\n",
|
||||
" 'model.additional_pos_embed.weight']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"orig = list(a.keys())\n",
|
||||
"pprint(orig)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"a = torch.load(original_ckpt_path)\n",
|
||||
"\n",
|
||||
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
|
||||
" 'model.transformer.decoder.layers.2.',\n",
|
||||
" 'model.transformer.decoder.layers.3.',\n",
|
||||
" 'model.transformer.decoder.layers.4.',\n",
|
||||
" 'model.transformer.decoder.layers.5.',\n",
|
||||
" 'model.transformer.decoder.layers.6.',\n",
|
||||
" 'model.transformer.decoder.norm.',\n",
|
||||
" 'model.is_pad_head']\n",
|
||||
"\n",
|
||||
"to_remove_in = ['num_batches_tracked',]\n",
|
||||
"\n",
|
||||
"conv = {}\n",
|
||||
"\n",
|
||||
"keys = list(a.keys())\n",
|
||||
"for k in keys:\n",
|
||||
" if any(k.startswith(tr) for tr in to_remove_startswith):\n",
|
||||
" a.pop(k)\n",
|
||||
" continue\n",
|
||||
" if any(tr in k for tr in to_remove_in):\n",
|
||||
" a.pop(k)\n",
|
||||
" continue\n",
|
||||
" if k.startswith('model.transformer.encoder.layers.'):\n",
|
||||
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
|
||||
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
|
||||
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder.layers.'):\n",
|
||||
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.action_head.'):\n",
|
||||
" conv[k] = a.pop(k)\n",
|
||||
" if k.startswith('model.pos_table'):\n",
|
||||
" conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
|
||||
" if k.startswith('model.query_embed.'):\n",
|
||||
" conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
|
||||
" if k.startswith('model.input_proj.'):\n",
|
||||
" conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.input_proj_robot_state.'):\n",
|
||||
" conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.backbones.0.0.body.'):\n",
|
||||
" conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
|
||||
" if k.startswith('model.cls_embed.'):\n",
|
||||
" conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder_action_proj.'):\n",
|
||||
" conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder_joint_proj.'):\n",
|
||||
" conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.latent_proj.'):\n",
|
||||
" conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.latent_out_proj.'):\n",
|
||||
" conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.additional_pos_embed.'):\n",
|
||||
" conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"OrderedDict()"
|
||||
]
|
||||
},
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for k, v in conv.items():\n",
|
||||
" assert b[k].shape == v.shape\n",
|
||||
" b[k] = v"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_file(b, converted_ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'"
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now also copy the config files\n",
|
||||
"import shutil\n",
|
||||
"shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n",
|
||||
"shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lerobot",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
8
examples/real_robot_example/gym_real_world/__init__.py
Normal file
8
examples/real_robot_example/gym_real_world/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
register(
|
||||
id="gym_real_world/RealEnv-v0",
|
||||
entry_point="gym_real_world.gym_environment:RealEnv",
|
||||
max_episode_steps=300,
|
||||
nondeterministic=True,
|
||||
)
|
||||
363
examples/real_robot_example/gym_real_world/dynamixel.py
Normal file
363
examples/real_robot_example/gym_real_world/dynamixel.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# ruff: noqa
|
||||
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
|
||||
Dynamixel class to control the dynamixel servos
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
from dynamixel_sdk import * # Uses Dynamixel SDK library
|
||||
|
||||
|
||||
def pos2pwm(pos: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
:param pos: numpy array of joint positions in range [-pi, pi]
|
||||
:return: numpy array of pwm values in range [0, 4096]
|
||||
"""
|
||||
return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)
|
||||
|
||||
|
||||
def pwm2pos(pwm: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
:param pwm: numpy array of pwm values in range [0, 4096]
|
||||
:return: numpy array of joint positions in range [-pi, pi]
|
||||
"""
|
||||
return (pwm / 2048 - 1) * 3.14
|
||||
|
||||
|
||||
def pwm2vel(pwm: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
:param pwm: numpy array of pwm/s joint velocities
|
||||
:return: numpy array of rad/s joint velocities
|
||||
"""
|
||||
return pwm * 3.14 / 2048
|
||||
|
||||
|
||||
def vel2pwm(vel: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
:param vel: numpy array of rad/s joint velocities
|
||||
:return: numpy array of pwm/s joint velocities
|
||||
"""
|
||||
return (vel * 2048 / 3.14).astype(np.int64)
|
||||
|
||||
|
||||
class ReadAttribute(enum.Enum):
|
||||
TEMPERATURE = 146
|
||||
VOLTAGE = 145
|
||||
VELOCITY = 128
|
||||
POSITION = 132
|
||||
CURRENT = 126
|
||||
PWM = 124
|
||||
HARDWARE_ERROR_STATUS = 70
|
||||
HOMING_OFFSET = 20
|
||||
BAUDRATE = 8
|
||||
|
||||
|
||||
class OperatingMode(enum.Enum):
|
||||
VELOCITY = 1
|
||||
POSITION = 3
|
||||
CURRENT_CONTROLLED_POSITION = 5
|
||||
PWM = 16
|
||||
UNKNOWN = -1
|
||||
|
||||
|
||||
class Dynamixel:
|
||||
ADDR_TORQUE_ENABLE = 64
|
||||
ADDR_GOAL_POSITION = 116
|
||||
ADDR_VELOCITY_LIMIT = 44
|
||||
ADDR_GOAL_PWM = 100
|
||||
OPERATING_MODE_ADDR = 11
|
||||
POSITION_I = 82
|
||||
POSITION_P = 84
|
||||
ADDR_ID = 7
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
def instantiate(self):
|
||||
return Dynamixel(self)
|
||||
|
||||
baudrate: int = 57600
|
||||
protocol_version: float = 2.0
|
||||
device_name: str = "" # /dev/tty.usbserial-1120'
|
||||
dynamixel_id: int = 1
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self.connect()
|
||||
|
||||
def connect(self):
|
||||
if self.config.device_name == "":
|
||||
for port_name in os.listdir("/dev"):
|
||||
if "ttyUSB" in port_name or "ttyACM" in port_name:
|
||||
self.config.device_name = "/dev/" + port_name
|
||||
print(f"using device {self.config.device_name}")
|
||||
self.portHandler = PortHandler(self.config.device_name)
|
||||
# self.portHandler.LA
|
||||
self.packetHandler = PacketHandler(self.config.protocol_version)
|
||||
if not self.portHandler.openPort():
|
||||
raise Exception(f"Failed to open port {self.config.device_name}")
|
||||
|
||||
if not self.portHandler.setBaudRate(self.config.baudrate):
|
||||
raise Exception(f"failed to set baudrate to {self.config.baudrate}")
|
||||
|
||||
# self.operating_mode = OperatingMode.UNKNOWN
|
||||
# self.torque_enabled = False
|
||||
# self._disable_torque()
|
||||
|
||||
self.operating_modes = [None for _ in range(32)]
|
||||
self.torque_enabled = [None for _ in range(32)]
|
||||
return True
|
||||
|
||||
def disconnect(self):
|
||||
self.portHandler.closePort()
|
||||
|
||||
def set_goal_position(self, motor_id, goal_position):
|
||||
# if self.operating_modes[motor_id] is not OperatingMode.POSITION:
|
||||
# self._disable_torque(motor_id)
|
||||
# self.set_operating_mode(motor_id, OperatingMode.POSITION)
|
||||
|
||||
# if not self.torque_enabled[motor_id]:
|
||||
# self._enable_torque(motor_id)
|
||||
|
||||
# self._enable_torque(motor_id)
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
|
||||
)
|
||||
# self._process_response(dxl_comm_result, dxl_error)
|
||||
# print(f'set position of motor {motor_id} to {goal_position}')
|
||||
|
||||
def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
|
||||
if self.operating_modes[motor_id] is not OperatingMode.PWM:
|
||||
self._disable_torque(motor_id)
|
||||
self.set_operating_mode(motor_id, OperatingMode.PWM)
|
||||
|
||||
if not self.torque_enabled[motor_id]:
|
||||
self._enable_torque(motor_id)
|
||||
# print(f'enabling torque')
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
|
||||
)
|
||||
# self._process_response(dxl_comm_result, dxl_error)
|
||||
# print(f'set pwm of motor {motor_id} to {pwm_value}')
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
if tries <= 1:
|
||||
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
|
||||
else:
|
||||
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
|
||||
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
|
||||
elif dxl_error != 0:
|
||||
print(f"dxl error {dxl_error}")
|
||||
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
|
||||
|
||||
def read_temperature(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
|
||||
|
||||
def read_velocity(self, motor_id: int):
|
||||
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
|
||||
if pos > 2**31:
|
||||
pos -= 2**32
|
||||
# print(f'read position {pos} for motor {motor_id}')
|
||||
return pos
|
||||
|
||||
def read_position(self, motor_id: int):
|
||||
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
|
||||
if pos > 2**31:
|
||||
pos -= 2**32
|
||||
# print(f'read position {pos} for motor {motor_id}')
|
||||
return pos
|
||||
|
||||
def read_position_degrees(self, motor_id: int) -> float:
|
||||
return (self.read_position(motor_id) / 4096) * 360
|
||||
|
||||
def read_position_radians(self, motor_id: int) -> float:
|
||||
return (self.read_position(motor_id) / 4096) * 2 * math.pi
|
||||
|
||||
def read_current(self, motor_id: int):
|
||||
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
|
||||
if current > 2**15:
|
||||
current -= 2**16
|
||||
return current
|
||||
|
||||
def read_present_pwm(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.PWM, 2)
|
||||
|
||||
def read_hardware_error_status(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
|
||||
|
||||
def disconnect(self):
|
||||
self.portHandler.closePort()
|
||||
|
||||
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
|
||||
"""
|
||||
sets the id of the dynamixel servo
|
||||
@param old_id: current id of the servo
|
||||
@param new_id: new id
|
||||
@param use_broadcast_id: set ids of all connected dynamixels if True.
|
||||
If False, change only servo with self.config.id
|
||||
@return:
|
||||
"""
|
||||
if use_broadcast_id:
|
||||
current_id = 254
|
||||
else:
|
||||
current_id = old_id
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, current_id, self.ADDR_ID, new_id
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, old_id)
|
||||
self.config.id = id
|
||||
|
||||
def _enable_torque(self, motor_id):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.torque_enabled[motor_id] = True
|
||||
|
||||
def _disable_torque(self, motor_id):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.torque_enabled[motor_id] = False
|
||||
|
||||
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
|
||||
)
|
||||
elif dxl_error != 0:
|
||||
print(f"dxl error {dxl_error}")
|
||||
raise ConnectionError(
|
||||
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
|
||||
)
|
||||
|
||||
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.operating_modes[motor_id] = operating_mode
|
||||
|
||||
def set_pwm_limit(self, motor_id: int, limit: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_velocity_limit(self, motor_id: int, velocity_limit):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_P(self, motor_id: int, P: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.POSITION_P, P
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_I(self, motor_id: int, I: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.POSITION_I, I
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def read_home_offset(self, motor_id: int):
|
||||
self._disable_torque(motor_id)
|
||||
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
|
||||
# ReadAttribute.HOMING_OFFSET.value, home_position)
|
||||
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
|
||||
# self._process_response(dxl_comm_result, dxl_error)
|
||||
self._enable_torque(motor_id)
|
||||
return home_offset
|
||||
|
||||
def set_home_offset(self, motor_id: int, home_position: int):
|
||||
self._disable_torque(motor_id)
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self._enable_torque(motor_id)
|
||||
|
||||
def set_baudrate(self, motor_id: int, baudrate):
|
||||
# translate baudrate into dynamixel baudrate setting id
|
||||
if baudrate == 57600:
|
||||
baudrate_id = 1
|
||||
elif baudrate == 1_000_000:
|
||||
baudrate_id = 3
|
||||
elif baudrate == 2_000_000:
|
||||
baudrate_id = 4
|
||||
elif baudrate == 3_000_000:
|
||||
baudrate_id = 5
|
||||
elif baudrate == 4_000_000:
|
||||
baudrate_id = 6
|
||||
else:
|
||||
raise Exception("baudrate not implemented")
|
||||
|
||||
self._disable_torque(motor_id)
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
|
||||
try:
|
||||
if num_bytes == 1:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
elif num_bytes == 2:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
elif num_bytes == 4:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
except Exception:
|
||||
if tries == 0:
|
||||
raise Exception
|
||||
else:
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
if tries <= 1:
|
||||
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
|
||||
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
|
||||
else:
|
||||
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
|
||||
time.sleep(0.02)
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
|
||||
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
|
||||
if tries == 0 and dxl_error != 128:
|
||||
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
|
||||
else:
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
return value
|
||||
|
||||
def set_home_position(self, motor_id: int):
|
||||
print(f"setting home position for motor {motor_id}")
|
||||
self.set_home_offset(motor_id, 0)
|
||||
current_position = self.read_position(motor_id)
|
||||
print(f"position before {current_position}")
|
||||
self.set_home_offset(motor_id, -current_position)
|
||||
# dynamixel.set_home_offset(motor_id, -4096)
|
||||
# dynamixel.set_home_offset(motor_id, -4294964109)
|
||||
current_position = self.read_position(motor_id)
|
||||
# print(f'signed position {current_position - 2** 32}')
|
||||
print(f"position after {current_position}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
|
||||
motor_id = 1
|
||||
pos = dynamixel.read_position(motor_id)
|
||||
for i in range(10):
|
||||
s = time.monotonic()
|
||||
pos = dynamixel.read_position(motor_id)
|
||||
delta = time.monotonic() - s
|
||||
print(f"read position took {delta}")
|
||||
print(f"position {pos}")
|
||||
192
examples/real_robot_example/gym_real_world/gym_environment.py
Normal file
192
examples/real_robot_example/gym_real_world/gym_environment.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from .dynamixel import pos2pwm, pwm2pos
|
||||
from .robot import Robot
|
||||
|
||||
FPS = 30
|
||||
|
||||
CAMERAS_SHAPES = {
|
||||
"images.high": (480, 640, 3),
|
||||
"images.low": (480, 640, 3),
|
||||
}
|
||||
|
||||
CAMERAS_PORTS = {
|
||||
"images.high": "/dev/video6",
|
||||
"images.low": "/dev/video0",
|
||||
}
|
||||
|
||||
LEADER_PORT = "/dev/ttyACM1"
|
||||
FOLLOWER_PORT = "/dev/ttyACM0"
|
||||
|
||||
MockRobot = MagicMock()
|
||||
MockRobot.read_position = MagicMock()
|
||||
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
MockCamera = MagicMock()
|
||||
MockCamera.isOpened = MagicMock(return_value=True)
|
||||
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
|
||||
|
||||
|
||||
def capture_image(cam, cam_width, cam_height):
|
||||
# Capture a single frame
|
||||
_, frame = cam.read()
|
||||
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
# # Define your crop coordinates (top left corner and bottom right corner)
|
||||
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
|
||||
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
|
||||
# # Crop the image
|
||||
# image = image[y1:y2, x1:x2]
|
||||
# Resize the image
|
||||
image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class RealEnv(gym.Env):
|
||||
metadata = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
record: bool = False,
|
||||
num_joints: int = 6,
|
||||
cameras_shapes: dict = CAMERAS_SHAPES,
|
||||
cameras_ports: dict = CAMERAS_PORTS,
|
||||
follower_port: str = FOLLOWER_PORT,
|
||||
leader_port: str = LEADER_PORT,
|
||||
warmup_steps: int = 100,
|
||||
trigger_torque=70,
|
||||
fps: int = FPS,
|
||||
fps_tolerance: float = 0.1,
|
||||
mock: bool = False,
|
||||
):
|
||||
self.num_joints = num_joints
|
||||
self.cameras_shapes = cameras_shapes
|
||||
self.cameras_ports = cameras_ports
|
||||
self.warmup_steps = warmup_steps
|
||||
assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."
|
||||
|
||||
self.follower_port = follower_port
|
||||
self.leader_port = leader_port
|
||||
self.record = record
|
||||
self.fps = fps
|
||||
self.fps_tolerance = fps_tolerance
|
||||
|
||||
# Initialize the robot
|
||||
self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
|
||||
if self.record:
|
||||
self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
|
||||
self.leader.set_trigger_torque(trigger_torque)
|
||||
|
||||
# Initialize the cameras - sorted by camera names
|
||||
self.cameras = {}
|
||||
for cn, p in sorted(self.cameras_ports.items()):
|
||||
self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
|
||||
if not self.cameras[cn].isOpened():
|
||||
raise OSError(
|
||||
f"Cannot open camera port {p} for {cn}."
|
||||
f" Make sure the camera is connected and the port is correct."
|
||||
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
|
||||
)
|
||||
|
||||
# Specify gym action and observation spaces
|
||||
observation_space = {}
|
||||
|
||||
if self.num_joints > 0:
|
||||
observation_space["agent_pos"] = spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(num_joints,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
if self.record:
|
||||
observation_space["leader_pos"] = spaces.Box(
|
||||
low=-1000.0,
|
||||
high=1000.0,
|
||||
shape=(num_joints,),
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
if self.cameras_shapes:
|
||||
for cn, hwc_shape in self.cameras_shapes.items():
|
||||
# Assumes images are unsigned int8 in [0,255]
|
||||
observation_space[cn] = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
# height x width x channels (e.g. 480 x 640 x 3)
|
||||
shape=hwc_shape,
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
self.observation_space = spaces.Dict(observation_space)
|
||||
self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)
|
||||
|
||||
self._observation = {}
|
||||
self._terminated = False
|
||||
self.timestamps = []
|
||||
|
||||
def _get_obs(self):
|
||||
qpos = self.follower.read_position()
|
||||
self._observation["agent_pos"] = pwm2pos(qpos)
|
||||
for cn, c in self.cameras.items():
|
||||
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
|
||||
|
||||
if self.record:
|
||||
action = self.leader.read_position()
|
||||
self._observation["leader_pos"] = pwm2pos(action)
|
||||
|
||||
def reset(self, seed: int | None = None):
|
||||
# Reset the robot and sync the leader and follower if we are recording
|
||||
for _ in range(self.warmup_steps):
|
||||
self._get_obs()
|
||||
if self.record:
|
||||
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||
self._terminated = False
|
||||
info = {}
|
||||
self.timestamps = []
|
||||
return self._observation, info
|
||||
|
||||
def step(self, action: np.ndarray = None):
|
||||
if self.timestamps:
|
||||
# wait the right amount of time to stay at the desired fps
|
||||
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
|
||||
|
||||
self.timestamps.append(time.time())
|
||||
|
||||
# Get the observation
|
||||
self._get_obs()
|
||||
if self.record:
|
||||
# Teleoperate the leader
|
||||
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||
else:
|
||||
# Apply the action to the follower
|
||||
self.follower.set_goal_pos(pos2pwm(action))
|
||||
|
||||
reward = 0
|
||||
terminated = truncated = self._terminated
|
||||
info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
|
||||
|
||||
# Check if we are able to keep up with the desired fps
|
||||
if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
|
||||
self.fps - self.fps_tolerance
|
||||
):
|
||||
print(
|
||||
f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
|
||||
f" than min admited fps {(self.fps - self.fps_tolerance):.5f}"
|
||||
f" at frame {len(self.timestamps)}"
|
||||
)
|
||||
info["fps_error"] = True
|
||||
|
||||
return self._observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self): ...
|
||||
|
||||
def close(self):
|
||||
self.follower._disable_torque()
|
||||
if self.record:
|
||||
self.leader._disable_torque()
|
||||
173
examples/real_robot_example/gym_real_world/robot.py
Normal file
173
examples/real_robot_example/gym_real_world/robot.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# ruff: noqa
|
||||
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
|
||||
Class to control the robot using dynamixel servos.
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite
|
||||
|
||||
from .dynamixel import Dynamixel, OperatingMode, ReadAttribute
|
||||
|
||||
|
||||
class MotorControlType(Enum):
|
||||
PWM = auto()
|
||||
POSITION_CONTROL = auto()
|
||||
DISABLED = auto()
|
||||
UNKNOWN = auto()
|
||||
|
||||
|
||||
class Robot:
|
||||
def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
|
||||
self.servo_ids = servo_ids
|
||||
self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
|
||||
self._init_motors()
|
||||
|
||||
def _init_motors(self):
|
||||
self.position_reader = GroupSyncRead(
|
||||
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
|
||||
)
|
||||
for id in self.servo_ids:
|
||||
self.position_reader.addParam(id)
|
||||
|
||||
self.velocity_reader = GroupSyncRead(
|
||||
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
|
||||
)
|
||||
for id in self.servo_ids:
|
||||
self.velocity_reader.addParam(id)
|
||||
|
||||
self.pos_writer = GroupSyncWrite(
|
||||
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
|
||||
)
|
||||
for id in self.servo_ids:
|
||||
self.pos_writer.addParam(id, [2048])
|
||||
|
||||
self.pwm_writer = GroupSyncWrite(
|
||||
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
|
||||
)
|
||||
for id in self.servo_ids:
|
||||
self.pwm_writer.addParam(id, [2048])
|
||||
# self._disable_torque()
|
||||
self.motor_control_state = MotorControlType.DISABLED
|
||||
|
||||
def read_position(self, tries=2):
|
||||
"""
|
||||
Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
|
||||
:param tries: maximum number of tries to read the position
|
||||
:return: list of joint positions in range [0, 4096]
|
||||
"""
|
||||
result = self.position_reader.txRxPacket()
|
||||
if result != 0:
|
||||
if tries > 0:
|
||||
return self.read_position(tries=tries - 1)
|
||||
else:
|
||||
print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||
positions = []
|
||||
for id in self.servo_ids:
|
||||
position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
|
||||
if position > 2**31:
|
||||
position -= 2**32
|
||||
positions.append(position)
|
||||
return np.array(positions)
|
||||
|
||||
def read_velocity(self):
|
||||
"""
|
||||
Reads the joint velocities of the robot.
|
||||
:return: list of joint velocities,
|
||||
"""
|
||||
self.velocity_reader.txRxPacket()
|
||||
velocties = []
|
||||
for id in self.servo_ids:
|
||||
velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
|
||||
if velocity > 2**31:
|
||||
velocity -= 2**32
|
||||
velocties.append(velocity)
|
||||
return np.array(velocties)
|
||||
|
||||
def set_goal_pos(self, action):
|
||||
"""
|
||||
:param action: list or numpy array of target joint positions in range [0, 4096[
|
||||
"""
|
||||
if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
|
||||
self._set_position_control()
|
||||
for i, motor_id in enumerate(self.servo_ids):
|
||||
data_write = [
|
||||
DXL_LOBYTE(DXL_LOWORD(action[i])),
|
||||
DXL_HIBYTE(DXL_LOWORD(action[i])),
|
||||
DXL_LOBYTE(DXL_HIWORD(action[i])),
|
||||
DXL_HIBYTE(DXL_HIWORD(action[i])),
|
||||
]
|
||||
self.pos_writer.changeParam(motor_id, data_write)
|
||||
|
||||
self.pos_writer.txPacket()
|
||||
|
||||
def set_pwm(self, action):
|
||||
"""
|
||||
Sets the pwm values for the servos.
|
||||
:param action: list or numpy array of pwm values in range [0, 885]
|
||||
"""
|
||||
if self.motor_control_state is not MotorControlType.PWM:
|
||||
self._set_pwm_control()
|
||||
for i, motor_id in enumerate(self.servo_ids):
|
||||
data_write = [
|
||||
DXL_LOBYTE(DXL_LOWORD(action[i])),
|
||||
DXL_HIBYTE(DXL_LOWORD(action[i])),
|
||||
]
|
||||
self.pwm_writer.changeParam(motor_id, data_write)
|
||||
|
||||
self.pwm_writer.txPacket()
|
||||
|
||||
def set_trigger_torque(self, torque: int):
|
||||
"""
|
||||
Sets a constant torque torque for the last servo in the chain. This is useful for the trigger of the leader arm
|
||||
"""
|
||||
self.dynamixel._enable_torque(self.servo_ids[-1])
|
||||
self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)
|
||||
|
||||
def limit_pwm(self, limit: Union[int, list, np.ndarray]):
|
||||
"""
|
||||
Limits the pwm values for the servos in for position control
|
||||
@param limit: 0 ~ 885
|
||||
@return:
|
||||
"""
|
||||
if isinstance(limit, int):
|
||||
limits = [
|
||||
limit,
|
||||
] * 5
|
||||
else:
|
||||
limits = limit
|
||||
self._disable_torque()
|
||||
for motor_id, limit in zip(self.servo_ids, limits, strict=False):
|
||||
self.dynamixel.set_pwm_limit(motor_id, limit)
|
||||
self._enable_torque()
|
||||
|
||||
def _disable_torque(self):
|
||||
print(f"disabling torque for servos {self.servo_ids}")
|
||||
for motor_id in self.servo_ids:
|
||||
self.dynamixel._disable_torque(motor_id)
|
||||
|
||||
def _enable_torque(self):
|
||||
print(f"enabling torque for servos {self.servo_ids}")
|
||||
for motor_id in self.servo_ids:
|
||||
self.dynamixel._enable_torque(motor_id)
|
||||
|
||||
def _set_pwm_control(self):
|
||||
self._disable_torque()
|
||||
for motor_id in self.servo_ids:
|
||||
self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
|
||||
self._enable_torque()
|
||||
self.motor_control_state = MotorControlType.PWM
|
||||
|
||||
def _set_position_control(self):
|
||||
self._disable_torque()
|
||||
for motor_id in self.servo_ids:
|
||||
# TODO(rcadene): redesign
|
||||
if motor_id == 9:
|
||||
self.dynamixel.set_operating_mode(9, OperatingMode.CURRENT_CONTROLLED_POSITION)
|
||||
else:
|
||||
self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
|
||||
|
||||
self._enable_torque()
|
||||
self.motor_control_state = MotorControlType.POSITION_CONTROL
|
||||
237
examples/real_robot_example/record_training_data.py
Normal file
237
examples/real_robot_example/record_training_data.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""This script demonstrates how to record a LeRobot dataset of training data
|
||||
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import gym_real_world # noqa: F401
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, Features, Sequence, Value
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
|
||||
|
||||
# parse the repo_id name via command line
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
|
||||
parser.add_argument("--num-episodes", type=int, default=2)
|
||||
parser.add_argument("--num-frames", type=int, default=400)
|
||||
parser.add_argument("--num-workers", type=int, default=16)
|
||||
parser.add_argument("--keep-last", action="store_true")
|
||||
parser.add_argument("--data_dir", type=str, default=None)
|
||||
parser.add_argument("--push-to-hub", action="store_true")
|
||||
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
|
||||
parser.add_argument(
|
||||
"--fps_tolerance",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Tolerance in fps for the recording before dropping episodes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
|
||||
)
|
||||
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
|
||||
parser.add_argument("--mock_robot", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
repo_id = args.repo_id
|
||||
num_episodes = args.num_episodes
|
||||
num_frames = args.num_frames
|
||||
revision = args.revision
|
||||
fps = args.fps
|
||||
fps_tolerance = args.fps_tolerance
|
||||
|
||||
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
|
||||
|
||||
# During data collection, frames are stored as png images in `images_dir`
|
||||
images_dir = out_data / "images"
|
||||
# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
|
||||
videos_dir = out_data / "videos"
|
||||
meta_data_dir = out_data / "meta_data"
|
||||
|
||||
gym_config = None
|
||||
if args.config is not None:
|
||||
gym_config = OmegaConf.load(args.config)
|
||||
|
||||
# Create image and video directories
|
||||
if not os.path.exists(images_dir):
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
if not os.path.exists(videos_dir):
|
||||
os.makedirs(videos_dir, exist_ok=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
|
||||
gym_handle = "gym_real_world/RealEnv-v0"
|
||||
gym_kwargs = {}
|
||||
if gym_config is not None:
|
||||
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
|
||||
env = gym.make(
|
||||
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
|
||||
)
|
||||
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
ep_fps = []
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
os.system('spd-say "gym environment created"')
|
||||
|
||||
ep_idx = 0
|
||||
while ep_idx < num_episodes:
|
||||
# bring the follower to the leader and start camera
|
||||
env.reset()
|
||||
|
||||
os.system(f'spd-say "go {ep_idx}"')
|
||||
# init buffers
|
||||
obs_replay = {k: [] for k in env.observation_space}
|
||||
|
||||
drop_episode = False
|
||||
timestamps = []
|
||||
for _ in tqdm(range(num_frames)):
|
||||
# Apply the next action
|
||||
observation, _, _, _, info = env.step(action=None)
|
||||
# images_stacked = np.hstack(list(observation['pixels'].values()))
|
||||
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
|
||||
# cv2.imshow('frame', images_stacked)
|
||||
|
||||
if info["fps_error"]:
|
||||
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
|
||||
drop_episode = True
|
||||
break
|
||||
|
||||
# store data
|
||||
for key in observation:
|
||||
obs_replay[key].append(copy.deepcopy(observation[key]))
|
||||
timestamps.append(info["timestamp"])
|
||||
|
||||
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
# break
|
||||
|
||||
os.system('spd-say "stop"')
|
||||
|
||||
if not drop_episode:
|
||||
os.system(f'spd-say "saving episode {ep_idx}"')
|
||||
ep_dict = {}
|
||||
# store images in png and create the video
|
||||
for img_key in env.cameras:
|
||||
save_images_concurrently(
|
||||
obs_replay[img_key],
|
||||
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
||||
args.num_workers,
|
||||
)
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
# store the reference to the video frame
|
||||
ep_dict[f"observation.{img_key}"] = [
|
||||
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
|
||||
]
|
||||
|
||||
state = torch.tensor(np.array(obs_replay["agent_pos"]))
|
||||
action = torch.tensor(np.array(obs_replay["leader_pos"]))
|
||||
next_done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
next_done[-1] = True
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["action"] = action
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.tensor(timestamps)
|
||||
ep_dict["next.done"] = next_done
|
||||
ep_fps.append(num_frames / timestamps[-1])
|
||||
ep_dicts.append(ep_dict)
|
||||
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(
|
||||
id_from + num_frames if args.keep_last else id_from + num_frames - 1
|
||||
)
|
||||
|
||||
id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
|
||||
id_from = id_to
|
||||
|
||||
ep_idx += 1
|
||||
|
||||
env.close()
|
||||
|
||||
os.system('spd-say "encode video frames"')
|
||||
for ep_idx in range(num_episodes):
|
||||
for img_key in env.cameras:
|
||||
# If necessary, we may want to encode the video
|
||||
# with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
|
||||
encode_video_frames(
|
||||
images_dir / f"{img_key}_episode_{ep_idx:06d}",
|
||||
videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
|
||||
ep_fps[ep_idx],
|
||||
)
|
||||
|
||||
os.system('spd-say "concatenate episodes"')
|
||||
data_dict = concatenate_episodes(
|
||||
ep_dicts, drop_episodes_last_frame=not args.keep_last
|
||||
) # Since our fps varies we are sometimes off tolerance for the last frame
|
||||
|
||||
features = {}
|
||||
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
features[key] = VideoFrame()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
info = {
|
||||
"fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video
|
||||
"video": 1,
|
||||
}
|
||||
|
||||
os.system('spd-say "from preloaded"')
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
version=revision,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
os.system('spd-say "compute stats"')
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
|
||||
os.system('spd-say "save to disk"')
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(out_data / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
||||
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
||||
60
examples/real_robot_example/run_policy.py
Normal file
60
examples/real_robot_example/run_policy.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import gym_real_world # noqa: F401
|
||||
import gymnasium as gym # noqa: F401
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
from lerobot.scripts.eval import eval
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
|
||||
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
||||
@@ -1,36 +1,37 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
|
||||
# low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# policy=act_real \
|
||||
# env=aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/koch_pick_place_lego
|
||||
dataset_repo_id: ???
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
observation.images.high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
observation.images.low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
offline_steps: 1000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
save_freq: 1000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
@@ -42,11 +43,11 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
n_episodes: 1
|
||||
batch_size: 1
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
@@ -54,21 +55,21 @@ policy:
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.images.high: [3, 480, 640]
|
||||
observation.images.low: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.images.high: mean_std
|
||||
observation.images.low: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
@@ -0,0 +1,103 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
|
||||
# front) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real \
|
||||
# env=aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: ???
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.top:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.front:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 1000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 1000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
batch_size: 1
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.top: [3, 480, 640]
|
||||
observation.images.front: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.top: mean_std
|
||||
observation.images.front: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -70,8 +70,6 @@ available_datasets_per_env = {
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted_image",
|
||||
],
|
||||
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
|
||||
# coupled with tests.
|
||||
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
|
||||
"xarm": [
|
||||
"lerobot/xarm_lift_medium",
|
||||
@@ -125,10 +123,6 @@ available_real_world_datasets = [
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
"lerobot/umi_cup_in_the_wild",
|
||||
"lerobot/unitreeh1_fold_clothes",
|
||||
"lerobot/unitreeh1_rearrange_objects",
|
||||
"lerobot/unitreeh1_two_robot_greeting",
|
||||
"lerobot/unitreeh1_warehouse",
|
||||
]
|
||||
|
||||
available_datasets = list(
|
||||
@@ -140,13 +134,12 @@ available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
"vqbet",
|
||||
]
|
||||
|
||||
# keys and values refer to yaml files
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "vqbet"],
|
||||
"pusht": ["diffusion"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
}
|
||||
|
||||
334
lerobot/common/datasets/_video_benchmark/README.md
Normal file
334
lerobot/common/datasets/_video_benchmark/README.md
Normal file
@@ -0,0 +1,334 @@
|
||||
# Video benchmark
|
||||
|
||||
|
||||
## Questions
|
||||
|
||||
What is the optimal trade-off between:
|
||||
- maximizing loading time with random access,
|
||||
- minimizing memory space on disk,
|
||||
- maximizing success rate of policies?
|
||||
|
||||
How to encode videos?
|
||||
- How much compression (`-crf`)? Low compression with `0`, normal compression with `20` or extreme with `56`?
|
||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
||||
- How many key frames (`-g`)? A key frame every `10` frames?
|
||||
|
||||
How to decode videos?
|
||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
||||
|
||||
## Metrics
|
||||
|
||||
**Percentage of data compression (higher is better)**
|
||||
`compression_factor` is the ratio of the memory space on disk taken by the original images to encode, to the memory space taken by the encoded video. For instance, `compression_factor=4` means that the video takes 4 times less memory space on disk compared to the original images.
|
||||
|
||||
**Percentage of loading time (higher is better)**
|
||||
`load_time_factor` is the ratio of the time it takes to load original images at given timestamps, to the time it takes to decode the exact same frames from the video. Higher is better. For instance, `load_time_factor=0.5` means that decoding from video is 2 times slower than loading the original images.
|
||||
|
||||
**Average L2 error per pixel (lower is better)**
|
||||
`avg_per_pixel_l2_error` is the average L2 error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
||||
|
||||
**Loss of a pretrained policy (higher is better)** (not available)
|
||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
||||
|
||||
**Success rate after retraining (higher is better)** (not available)
|
||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best.
|
||||
|
||||
|
||||
## Variables
|
||||
|
||||
**Image content**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, etc. Hence, we run this benchmark on two datasets: `pusht` (simulation) and `umi` (real-world outdoor).
|
||||
|
||||
**Requested timestamps**
|
||||
In this benchmark, we focus on the loading time of random access, so we are not interested in sequentially loading all frames of a video like in a movie. However, the number of consecutive timestamps requested and their spacing can greatly affect the `load_time_factor`. In fact, it is expected to get faster loading time by decoding a large number of consecutive frames from a video, than to load the same data from individual images. To reflect our robotics use case, we consider a few settings:
|
||||
- `single_frame`: 1 frame,
|
||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
||||
- `2_frames_4_space`: 2 consecutive frames with 4 frames of spacing (e.g `[t, t + 4 / fps]`),
|
||||
|
||||
**Data augmentations**
|
||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
**`decoder`**
|
||||
| repo_id | decoder | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | <span style="color: #32CD32;">torchvision</span> | 0.166 | 0.0000119 |
|
||||
| lerobot/pusht | ffmpegio | 0.009 | 0.0001182 |
|
||||
| lerobot/pusht | torchaudio | 0.138 | 0.0000359 |
|
||||
| lerobot/umi_cup_in_the_wild | <span style="color: #32CD32;">torchvision</span> | 0.174 | 0.0000174 |
|
||||
| lerobot/umi_cup_in_the_wild | ffmpegio | 0.010 | 0.0000735 |
|
||||
| lerobot/umi_cup_in_the_wild | torchaudio | 0.154 | 0.0000340 |
|
||||
|
||||
### `1_frame`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.224 | 0.0000760 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.185 | 0.0000443 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.388 | 0.0000469 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.329 | 0.0000397 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.204 | 0.0000556 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.182 | 0.0000443 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.174 | 0.0000450 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.163 | 0.0000448 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.163 | 0.0000436 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.155 | 0.0000427 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.130 | 0.0000410 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.105 | 0.0000406 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.097 | 0.0000400 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.061 | 0.0000405 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.039 | 0.0000422 |
|
||||
| lerobot/pusht | None | 8.909 | 0.024 | 0.0000431 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.444 | 0.0000601 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.345 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.282 | 0.0000416 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.271 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.260 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.249 | 0.0000415 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.195 | 0.0000399 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.169 | 0.0000394 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.140 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.096 | 0.0000384 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.046 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.022 | 0.0000400 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.175 | 0.0000035 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.181 | 0.0000080 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.172 | 0.0000123 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.187 | 0.0000211 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.181 | 0.0000346 |
|
||||
| lerobot/pusht | None | 3.646 | 0.189 | 0.0000443 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.186 | 0.0000521 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.184 | 0.0000850 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.193 | 0.0001726 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.183 | 0.0002606 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.165 | 0.0000056 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.171 | 0.0000111 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.212 | 0.0000153 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.261 | 0.0000218 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.312 | 0.0000317 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.339 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.297 | 0.0000452 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.406 | 0.0000629 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.468 | 0.0001184 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.515 | 0.0001879 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.188 | 0.0000443 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.339 | 0.0000397 |
|
||||
|
||||
### `2_frames`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.314 | 0.0000799 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.303 | 0.0000496 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.642 | 0.0000503 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.529 | 0.0000436 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.308 | 0.0000599 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.279 | 0.0000496 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.259 | 0.0000498 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.243 | 0.0000501 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.235 | 0.0000492 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.230 | 0.0000481 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.194 | 0.0000468 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.152 | 0.0000460 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.151 | 0.0000455 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.095 | 0.0000454 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.062 | 0.0000472 |
|
||||
| lerobot/pusht | None | 8.909 | 0.037 | 0.0000479 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.638 | 0.0000625 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.537 | 0.0000436 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.493 | 0.0000437 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.458 | 0.0000446 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.438 | 0.0000445 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.424 | 0.0000444 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.345 | 0.0000435 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.313 | 0.0000417 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.264 | 0.0000421 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.185 | 0.0000414 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.090 | 0.0000420 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.042 | 0.0000424 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.302 | 0.0000097 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.287 | 0.0000142 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.283 | 0.0000184 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.305 | 0.0000268 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.285 | 0.0000402 |
|
||||
| lerobot/pusht | None | 3.646 | 0.285 | 0.0000496 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.293 | 0.0000572 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.293 | 0.0000893 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.319 | 0.0001762 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.304 | 0.0002626 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.235 | 0.0000112 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.261 | 0.0000166 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.333 | 0.0000207 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.406 | 0.0000267 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.489 | 0.0000361 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.537 | 0.0000436 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.578 | 0.0000487 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.453 | 0.0000655 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.767 | 0.0001192 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.816 | 0.0001881 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.283 | 0.0000496 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.543 | 0.0000436 |
|
||||
|
||||
### `2_frames_4_space`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.257 | 0.0000855 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.261 | 0.0000556 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 0.493 | 0.0000476 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.371 | 0.0000404 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.226 | 0.0000670 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.222 | 0.0000556 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.217 | 0.0000567 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.204 | 0.0000555 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.179 | 0.0000556 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.188 | 0.0000544 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.160 | 0.0000531 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.150 | 0.0000521 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.123 | 0.0000519 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.092 | 0.0000519 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.053 | 0.0000533 |
|
||||
| lerobot/pusht | None | 8.909 | 0.034 | 0.0000541 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.409 | 0.0000607 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.381 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.355 | 0.0000418 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.346 | 0.0000425 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.354 | 0.0000419 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.336 | 0.0000419 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.314 | 0.0000402 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.269 | 0.0000397 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.246 | 0.0000395 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.171 | 0.0000390 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.091 | 0.0000399 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.043 | 0.0000409 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.212 | 0.0000193 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.211 | 0.0000232 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.199 | 0.0000270 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.198 | 0.0000347 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.211 | 0.0000469 |
|
||||
| lerobot/pusht | None | 3.646 | 0.206 | 0.0000556 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.210 | 0.0000626 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.223 | 0.0000927 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.227 | 0.0001763 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.223 | 0.0002625 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.147 | 0.0000071 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.182 | 0.0000125 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.222 | 0.0000166 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.270 | 0.0000229 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.325 | 0.0000326 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.362 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.390 | 0.0000459 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 0.437 | 0.0000633 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 0.499 | 0.0001186 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 0.564 | 0.0001879 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.224 | 0.0000556 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.368 | 0.0000404 |
|
||||
|
||||
### `6_frames`
|
||||
|
||||
**`pix_fmt`**
|
||||
| repo_id | pix_fmt | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | yuv420p | 3.788 | 0.660 | 0.0000839 |
|
||||
| lerobot/pusht | yuv444p | 3.646 | 0.546 | 0.0000542 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv420p | 14.391 | 1.225 | 0.0000497 |
|
||||
| lerobot/umi_cup_in_the_wild | yuv444p | 14.932 | 0.908 | 0.0000428 |
|
||||
|
||||
**`g`**
|
||||
| repo_id | g | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 1 | 2.543 | 0.552 | 0.0000646 |
|
||||
| lerobot/pusht | 2 | 3.646 | 0.534 | 0.0000542 |
|
||||
| lerobot/pusht | 3 | 4.431 | 0.563 | 0.0000546 |
|
||||
| lerobot/pusht | 4 | 5.103 | 0.537 | 0.0000545 |
|
||||
| lerobot/pusht | 5 | 5.625 | 0.477 | 0.0000532 |
|
||||
| lerobot/pusht | 6 | 5.974 | 0.515 | 0.0000530 |
|
||||
| lerobot/pusht | 10 | 6.814 | 0.410 | 0.0000512 |
|
||||
| lerobot/pusht | 15 | 7.431 | 0.405 | 0.0000503 |
|
||||
| lerobot/pusht | 20 | 7.662 | 0.345 | 0.0000500 |
|
||||
| lerobot/pusht | 40 | 8.163 | 0.247 | 0.0000496 |
|
||||
| lerobot/pusht | 100 | 8.761 | 0.147 | 0.0000510 |
|
||||
| lerobot/pusht | None | 8.909 | 0.100 | 0.0000519 |
|
||||
| lerobot/umi_cup_in_the_wild | 1 | 14.411 | 0.997 | 0.0000620 |
|
||||
| lerobot/umi_cup_in_the_wild | 2 | 14.932 | 0.911 | 0.0000428 |
|
||||
| lerobot/umi_cup_in_the_wild | 3 | 20.174 | 0.869 | 0.0000433 |
|
||||
| lerobot/umi_cup_in_the_wild | 4 | 24.889 | 0.874 | 0.0000438 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 28.825 | 0.864 | 0.0000439 |
|
||||
| lerobot/umi_cup_in_the_wild | 6 | 31.635 | 0.834 | 0.0000440 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 39.418 | 0.781 | 0.0000421 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 44.577 | 0.679 | 0.0000411 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 47.907 | 0.652 | 0.0000410 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 52.554 | 0.465 | 0.0000404 |
|
||||
| lerobot/umi_cup_in_the_wild | 100 | 58.241 | 0.245 | 0.0000413 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 60.530 | 0.116 | 0.0000417 |
|
||||
|
||||
**`crf`**
|
||||
| repo_id | crf | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| lerobot/pusht | 0 | 1.699 | 0.534 | 0.0000163 |
|
||||
| lerobot/pusht | 5 | 1.409 | 0.524 | 0.0000205 |
|
||||
| lerobot/pusht | 10 | 1.842 | 0.510 | 0.0000245 |
|
||||
| lerobot/pusht | 15 | 2.322 | 0.512 | 0.0000324 |
|
||||
| lerobot/pusht | 20 | 3.050 | 0.508 | 0.0000452 |
|
||||
| lerobot/pusht | None | 3.646 | 0.518 | 0.0000542 |
|
||||
| lerobot/pusht | 25 | 3.969 | 0.534 | 0.0000616 |
|
||||
| lerobot/pusht | 30 | 5.687 | 0.530 | 0.0000927 |
|
||||
| lerobot/pusht | 40 | 10.818 | 0.552 | 0.0001777 |
|
||||
| lerobot/pusht | 50 | 18.185 | 0.564 | 0.0002644 |
|
||||
| lerobot/umi_cup_in_the_wild | 0 | 1.918 | 0.401 | 0.0000101 |
|
||||
| lerobot/umi_cup_in_the_wild | 5 | 3.207 | 0.499 | 0.0000156 |
|
||||
| lerobot/umi_cup_in_the_wild | 10 | 4.818 | 0.599 | 0.0000197 |
|
||||
| lerobot/umi_cup_in_the_wild | 15 | 7.329 | 0.704 | 0.0000258 |
|
||||
| lerobot/umi_cup_in_the_wild | 20 | 11.361 | 0.834 | 0.0000352 |
|
||||
| lerobot/umi_cup_in_the_wild | None | 14.932 | 0.925 | 0.0000428 |
|
||||
| lerobot/umi_cup_in_the_wild | 25 | 17.741 | 0.978 | 0.0000480 |
|
||||
| lerobot/umi_cup_in_the_wild | 30 | 27.983 | 1.088 | 0.0000648 |
|
||||
| lerobot/umi_cup_in_the_wild | 40 | 82.449 | 1.324 | 0.0001190 |
|
||||
| lerobot/umi_cup_in_the_wild | 50 | 186.145 | 1.436 | 0.0001880 |
|
||||
|
||||
**best**
|
||||
| repo_id | compression_factor | load_time_factor | avg_per_pixel_l2_error |
|
||||
| --- | --- | --- | --- |
|
||||
| lerobot/pusht | 3.646 | 0.546 | 0.0000542 |
|
||||
| lerobot/umi_cup_in_the_wild | 14.932 | 0.934 | 0.0000428 |
|
||||
372
lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
Normal file
372
lerobot/common/datasets/_video_benchmark/run_video_benchmark.py
Normal file
@@ -0,0 +1,372 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
)
|
||||
|
||||
|
||||
def get_directory_size(directory):
|
||||
total_size = 0
|
||||
# Iterate over all files and subdirectories recursively
|
||||
for item in directory.rglob("*"):
|
||||
if item.is_file():
|
||||
# Add the file size to the total
|
||||
total_size += item.stat().st_size
|
||||
return total_size
|
||||
|
||||
|
||||
def run_video_benchmark(
|
||||
output_dir,
|
||||
cfg,
|
||||
timestamps_mode,
|
||||
seed=1337,
|
||||
):
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
repo_id = cfg["repo_id"]
|
||||
|
||||
# TODO(rcadene): rewrite with hardcoding of original images and episodes
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Get fps
|
||||
fps = dataset.fps
|
||||
|
||||
# we only load first episode
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
|
||||
# Save/Load image directory for the first episode
|
||||
imgs_dir = Path(f"tmp/data/images/{repo_id}/observation.image_episode_000000")
|
||||
if not imgs_dir.exists():
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
imgs_dataset = hf_dataset.select_columns("observation.image")
|
||||
|
||||
for i, item in enumerate(imgs_dataset):
|
||||
img = item["observation.image"]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
|
||||
sum_original_frames_size_bytes = get_directory_size(imgs_dir)
|
||||
|
||||
# Encode images into video
|
||||
video_path = output_dir / "episode_0.mp4"
|
||||
|
||||
g = cfg.get("g")
|
||||
crf = cfg.get("crf")
|
||||
pix_fmt = cfg["pix_fmt"]
|
||||
|
||||
cmd = f"ffmpeg -r {fps} "
|
||||
cmd += "-f image2 "
|
||||
cmd += "-loglevel error "
|
||||
cmd += f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||
cmd += "-vcodec libx264 "
|
||||
if g is not None:
|
||||
cmd += f"-g {g} " # ensures at least 1 keyframe every 10 frames
|
||||
# cmd += "-keyint_min 10 " set a minimum of 10 frames between 2 key frames
|
||||
# cmd += "-sc_threshold 0 " disable scene change detection to lower the number of key frames
|
||||
if crf is not None:
|
||||
cmd += f"-crf {crf} "
|
||||
cmd += f"-pix_fmt {pix_fmt} "
|
||||
cmd += f"{str(video_path)}"
|
||||
subprocess.run(cmd.split(" "), check=True)
|
||||
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
|
||||
# Set decoder
|
||||
|
||||
decoder = cfg["decoder"]
|
||||
decoder_kwgs = cfg["decoder_kwgs"]
|
||||
device = cfg["device"]
|
||||
|
||||
if decoder == "torchvision":
|
||||
decode_frames_fn = decode_video_frames_torchvision
|
||||
else:
|
||||
raise ValueError(decoder)
|
||||
|
||||
# Estimate average loading time
|
||||
|
||||
def load_original_frames(imgs_dir, timestamps):
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = torch.from_numpy(numpy.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
list_avg_load_time = []
|
||||
list_avg_load_time_from_images = []
|
||||
per_pixel_l2_errors = []
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
for t in range(50):
|
||||
# test loading 2 frames that are 4 frames appart, which might be a common setting
|
||||
ts = random.randint(fps, ep_num_images - fps) / fps
|
||||
|
||||
if timestamps_mode == "1_frame":
|
||||
timestamps = [ts]
|
||||
elif timestamps_mode == "2_frames":
|
||||
timestamps = [ts - 1 / fps, ts]
|
||||
elif timestamps_mode == "2_frames_4_space":
|
||||
timestamps = [ts - 4 / fps, ts]
|
||||
elif timestamps_mode == "6_frames":
|
||||
timestamps = [ts - i / fps for i in range(6)][::-1]
|
||||
else:
|
||||
raise ValueError(timestamps_mode)
|
||||
|
||||
num_frames = len(timestamps)
|
||||
|
||||
start_time_s = time.monotonic()
|
||||
frames = decode_frames_fn(
|
||||
video_path, timestamps=timestamps, tolerance_s=1e-4, device=device, **decoder_kwgs
|
||||
)
|
||||
avg_load_time = (time.monotonic() - start_time_s) / num_frames
|
||||
list_avg_load_time.append(avg_load_time)
|
||||
|
||||
start_time_s = time.monotonic()
|
||||
original_frames = load_original_frames(imgs_dir, timestamps)
|
||||
avg_load_time_from_images = (time.monotonic() - start_time_s) / num_frames
|
||||
list_avg_load_time_from_images.append(avg_load_time_from_images)
|
||||
|
||||
# Estimate average L2 error between original frames and decoded frames
|
||||
for i, ts in enumerate(timestamps):
|
||||
# are_close = torch.allclose(frames[i], original_frames[i], atol=0.02)
|
||||
num_pixels = original_frames[i].numel()
|
||||
per_pixel_l2_error = torch.norm(frames[i] - original_frames[i], p=2).item() / num_pixels
|
||||
|
||||
# save decoded frames
|
||||
if t == 0:
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(output_dir / f"frame_{i:06d}.png")
|
||||
|
||||
# save original_frames
|
||||
idx = int(ts * fps)
|
||||
if t == 0:
|
||||
original_frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
original_frame.save(output_dir / f"original_frame_{i:06d}.png")
|
||||
|
||||
per_pixel_l2_errors.append(per_pixel_l2_error)
|
||||
|
||||
avg_load_time = float(numpy.array(list_avg_load_time).mean())
|
||||
avg_load_time_from_images = float(numpy.array(list_avg_load_time_from_images).mean())
|
||||
avg_per_pixel_l2_error = float(numpy.array(per_pixel_l2_errors).mean())
|
||||
|
||||
# Save benchmark info
|
||||
|
||||
info = {
|
||||
"sum_original_frames_size_bytes": sum_original_frames_size_bytes,
|
||||
"video_size_bytes": video_size_bytes,
|
||||
"avg_load_time_from_images": avg_load_time_from_images,
|
||||
"avg_load_time": avg_load_time,
|
||||
"compression_factor": sum_original_frames_size_bytes / video_size_bytes,
|
||||
"load_time_factor": avg_load_time_from_images / avg_load_time,
|
||||
"avg_per_pixel_l2_error": avg_per_pixel_l2_error,
|
||||
}
|
||||
|
||||
with open(output_dir / "info.json", "w") as f:
|
||||
json.dump(info, f)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def display_markdown_table(headers, rows):
|
||||
for i, row in enumerate(rows):
|
||||
new_row = []
|
||||
for col in row:
|
||||
if col is None:
|
||||
new_col = "None"
|
||||
elif isinstance(col, float):
|
||||
new_col = f"{col:.3f}"
|
||||
if new_col == "0.000":
|
||||
new_col = f"{col:.7f}"
|
||||
elif isinstance(col, int):
|
||||
new_col = f"{col}"
|
||||
else:
|
||||
new_col = col
|
||||
new_row.append(new_col)
|
||||
rows[i] = new_row
|
||||
|
||||
header_line = "| " + " | ".join(headers) + " |"
|
||||
separator_line = "| " + " | ".join(["---" for _ in headers]) + " |"
|
||||
body_lines = ["| " + " | ".join(row) + " |" for row in rows]
|
||||
markdown_table = "\n".join([header_line, separator_line] + body_lines)
|
||||
print(markdown_table)
|
||||
print()
|
||||
|
||||
|
||||
def load_info(out_dir):
|
||||
with open(out_dir / "info.json") as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path("tmp/run_video_benchmark")
|
||||
dry_run = False
|
||||
repo_ids = ["lerobot/pusht", "lerobot/umi_cup_in_the_wild"]
|
||||
timestamps_modes = [
|
||||
"1_frame",
|
||||
"2_frames",
|
||||
"2_frames_4_space",
|
||||
"6_frames",
|
||||
]
|
||||
for timestamps_mode in timestamps_modes:
|
||||
bench_dir = out_dir / timestamps_mode
|
||||
|
||||
print(f"### `{timestamps_mode}`")
|
||||
print()
|
||||
|
||||
print("**`pix_fmt`**")
|
||||
headers = ["repo_id", "pix_fmt", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
|
||||
rows = []
|
||||
for repo_id in repo_ids:
|
||||
for pix_fmt in ["yuv420p", "yuv444p"]:
|
||||
cfg = {
|
||||
"repo_id": repo_id,
|
||||
# video encoding
|
||||
"g": 2,
|
||||
"crf": None,
|
||||
"pix_fmt": pix_fmt,
|
||||
# video decoding
|
||||
"device": "cpu",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
if not dry_run:
|
||||
run_video_benchmark(bench_dir / repo_id / f"torchvision_{pix_fmt}", cfg, timestamps_mode)
|
||||
info = load_info(bench_dir / repo_id / f"torchvision_{pix_fmt}")
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
pix_fmt,
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
print("**`g`**")
|
||||
headers = ["repo_id", "g", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
|
||||
rows = []
|
||||
for repo_id in repo_ids:
|
||||
for g in [1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None]:
|
||||
cfg = {
|
||||
"repo_id": repo_id,
|
||||
# video encoding
|
||||
"g": g,
|
||||
"pix_fmt": "yuv444p",
|
||||
# video decoding
|
||||
"device": "cpu",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
if not dry_run:
|
||||
run_video_benchmark(bench_dir / repo_id / f"torchvision_g_{g}", cfg, timestamps_mode)
|
||||
info = load_info(bench_dir / repo_id / f"torchvision_g_{g}")
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
g,
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
print("**`crf`**")
|
||||
headers = ["repo_id", "crf", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
|
||||
rows = []
|
||||
for repo_id in repo_ids:
|
||||
for crf in [0, 5, 10, 15, 20, None, 25, 30, 40, 50]:
|
||||
cfg = {
|
||||
"repo_id": repo_id,
|
||||
# video encoding
|
||||
"g": 2,
|
||||
"crf": crf,
|
||||
"pix_fmt": "yuv444p",
|
||||
# video decoding
|
||||
"device": "cpu",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
if not dry_run:
|
||||
run_video_benchmark(bench_dir / repo_id / f"torchvision_crf_{crf}", cfg, timestamps_mode)
|
||||
info = load_info(bench_dir / repo_id / f"torchvision_crf_{crf}")
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
crf,
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
print("**best**")
|
||||
headers = ["repo_id", "compression_factor", "load_time_factor", "avg_per_pixel_l2_error"]
|
||||
rows = []
|
||||
for repo_id in repo_ids:
|
||||
cfg = {
|
||||
"repo_id": repo_id,
|
||||
# video encoding
|
||||
"g": 2,
|
||||
"crf": None,
|
||||
"pix_fmt": "yuv444p",
|
||||
# video decoding
|
||||
"device": "cpu",
|
||||
"decoder": "torchvision",
|
||||
"decoder_kwgs": {},
|
||||
}
|
||||
if not dry_run:
|
||||
run_video_benchmark(bench_dir / repo_id / "torchvision_best", cfg, timestamps_mode)
|
||||
info = load_info(bench_dir / repo_id / "torchvision_best")
|
||||
rows.append(
|
||||
[
|
||||
repo_id,
|
||||
info["compression_factor"],
|
||||
info["load_time_factor"],
|
||||
info["avg_per_pixel_l2_error"],
|
||||
]
|
||||
)
|
||||
display_markdown_table(headers, rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,7 +19,6 @@ import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
@@ -57,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if not cfg.env.real_world:
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
@@ -72,39 +71,17 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = None
|
||||
if cfg.training.image_transforms.enable:
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
||||
@@ -35,41 +35,40 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v1.6"
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
CODEBASE_VERSION = "v1.4"
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
transform: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.videos_dir = load_videos(repo_id, version, root)
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
@@ -150,12 +149,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.video_frame_keys,
|
||||
self.videos_dir,
|
||||
self.tolerance_s,
|
||||
self.video_backend,
|
||||
)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.camera_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
@@ -163,6 +160,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
@@ -170,8 +168,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -179,6 +176,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def from_preloaded(
|
||||
cls,
|
||||
repo_id: str = "from_preloaded",
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
@@ -189,7 +187,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
stats=None,
|
||||
info=None,
|
||||
videos_dir=None,
|
||||
video_backend=None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||
|
||||
@@ -202,16 +199,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# create an empty object of type LeRobotDataset
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.version = version
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.image_transforms = transform
|
||||
obj.transform = transform
|
||||
obj.delta_timestamps = delta_timestamps
|
||||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
obj.stats = stats
|
||||
obj.info = info if info is not None else {}
|
||||
obj.videos_dir = videos_dir
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
||||
@@ -225,11 +222,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
transform: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -238,11 +235,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
version=version,
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=video_backend,
|
||||
transform=transform,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
@@ -274,9 +271,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
self.disabled_data_keys.update(extra_keys)
|
||||
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
@@ -382,13 +380,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
for data_key in self.disabled_data_keys:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
@@ -396,6 +394,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An online buffer for the online training loop in train.py
|
||||
|
||||
Note to maintainers: This duplicates some logic from LeRobotDataset and EpisodeAwareSampler. We should
|
||||
consider converging to one approach. Here we have opted to use numpy.memmap to back the data buffer. It's much
|
||||
faster than using HuggingFace Datasets as there's no conversion to an intermediate non-python object. Also it
|
||||
supports in-place slicing and mutation which is very handy for a dynamic buffer.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def _make_memmap_safe(**kwargs) -> np.memmap:
|
||||
"""Make a numpy memmap with checks on available disk space first.
|
||||
|
||||
Expected kwargs are: "filename", "dtype" (must by np.dtype), "mode" and "shape"
|
||||
|
||||
For information on dtypes:
|
||||
https://numpy.org/doc/stable/reference/arrays.dtypes.html#arrays-dtypes-constructing
|
||||
"""
|
||||
if kwargs["mode"].startswith("w"):
|
||||
required_space = kwargs["dtype"].itemsize * np.prod(kwargs["shape"]) # bytes
|
||||
stats = os.statvfs(Path(kwargs["filename"]).parent)
|
||||
available_space = stats.f_bavail * stats.f_frsize # bytes
|
||||
if required_space >= available_space * 0.8:
|
||||
raise RuntimeError(
|
||||
f"You're about to take up {required_space} of {available_space} bytes available."
|
||||
)
|
||||
return np.memmap(**kwargs)
|
||||
|
||||
|
||||
class OnlineBuffer(torch.utils.data.Dataset):
|
||||
"""FIFO data buffer for the online training loop in train.py.
|
||||
|
||||
Follows the protocol of LeRobotDataset as much as is required to have it be used by the online training
|
||||
loop in the same way that a LeRobotDataset would be used.
|
||||
|
||||
The underlying data structure will have data inserted in a circular fashion. Always insert after the
|
||||
last index, and when you reach the end, wrap around to the start.
|
||||
|
||||
The data is stored in a numpy memmap.
|
||||
"""
|
||||
|
||||
NEXT_INDEX_KEY = "_next_index"
|
||||
OCCUPANCY_MASK_KEY = "_occupancy_mask"
|
||||
INDEX_KEY = "index"
|
||||
FRAME_INDEX_KEY = "frame_index"
|
||||
EPISODE_INDEX_KEY = "episode_index"
|
||||
TIMESTAMP_KEY = "timestamp"
|
||||
IS_PAD_POSTFIX = "_is_pad"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
write_dir: str | Path,
|
||||
data_spec: dict[str, Any] | None,
|
||||
buffer_capacity: int | None,
|
||||
fps: float | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | dict[str, np.ndarray] | None = None,
|
||||
):
|
||||
"""
|
||||
The online buffer can be provided from scratch or you can load an existing online buffer by passing
|
||||
a `write_dir` associated with an existing buffer.
|
||||
|
||||
Args:
|
||||
write_dir: Where to keep the numpy memmap files. One memmap file will be stored for each data key.
|
||||
Note that if the files already exist, they are opened in read-write mode (used for training
|
||||
resumption.)
|
||||
data_spec: A mapping from data key to data specification, like {data_key: {"shape": tuple[int],
|
||||
"dtype": np.dtype}}. This should include all the data that you wish to record into the buffer,
|
||||
but note that "index", "frame_index" and "episode_index" are already accounted for by this
|
||||
class, so you don't need to include them.
|
||||
buffer_capacity: How many frames should be stored in the buffer as a maximum. Be aware of your
|
||||
system's available disk space when choosing this.
|
||||
fps: Same as the fps concept in LeRobot dataset. Here it needs to be provided for the
|
||||
delta_timestamps logic. You can pass None if you are not using delta_timestamps.
|
||||
delta_timestamps: Same as the delta_timestamps concept in LeRobotDataset. This is internally
|
||||
converted to dict[str, np.ndarray] for optimization purposes.
|
||||
|
||||
"""
|
||||
self.set_delta_timestamps(delta_timestamps)
|
||||
self._fps = fps
|
||||
# Tolerance in seconds used to discard loaded frames when their timestamps are not close enough from
|
||||
# the requested frames. It is only used when `delta_timestamps` is provided.
|
||||
# minus 1e-4 to account for possible numerical error
|
||||
self.tolerance_s = 1 / self.fps - 1e-4 if fps is not None else None
|
||||
self._buffer_capacity = buffer_capacity
|
||||
data_spec = self._make_data_spec(data_spec, buffer_capacity)
|
||||
Path(write_dir).mkdir(parents=True, exist_ok=True)
|
||||
self._data = {}
|
||||
for k, v in data_spec.items():
|
||||
self._data[k] = _make_memmap_safe(
|
||||
filename=Path(write_dir) / k,
|
||||
dtype=v["dtype"] if v is not None else None,
|
||||
mode="r+" if (Path(write_dir) / k).exists() else "w+",
|
||||
shape=tuple(v["shape"]) if v is not None else None,
|
||||
)
|
||||
|
||||
@property
|
||||
def delta_timestamps(self) -> dict[str, np.ndarray] | None:
|
||||
return self._delta_timestamps
|
||||
|
||||
def set_delta_timestamps(self, value: dict[str, list[float]] | None):
|
||||
"""Set delta_timestamps converting the values to numpy arrays.
|
||||
|
||||
The conversion is for an optimization in the __getitem__. The loop is much slower if the arrays
|
||||
need to be converted into numpy arrays.
|
||||
"""
|
||||
if value is not None:
|
||||
self._delta_timestamps = {k: np.array(v) for k, v in value.items()}
|
||||
else:
|
||||
self._delta_timestamps = None
|
||||
|
||||
def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]:
|
||||
"""Makes the data spec for np.memmap."""
|
||||
if any(k.startswith("_") for k in data_spec):
|
||||
raise ValueError(
|
||||
"data_spec keys should not start with '_'. This prefix is reserved for internal logic."
|
||||
)
|
||||
preset_keys = {
|
||||
OnlineBuffer.INDEX_KEY,
|
||||
OnlineBuffer.FRAME_INDEX_KEY,
|
||||
OnlineBuffer.EPISODE_INDEX_KEY,
|
||||
OnlineBuffer.TIMESTAMP_KEY,
|
||||
}
|
||||
if len(intersection := set(data_spec).intersection(preset_keys)) > 0:
|
||||
raise ValueError(
|
||||
f"data_spec should not contain any of {preset_keys} as these are handled internally. "
|
||||
f"The provided data_spec has {intersection}."
|
||||
)
|
||||
complete_data_spec = {
|
||||
# _next_index will be a pointer to the next index that we should start filling from when we add
|
||||
# more data.
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
"""Add new data to the buffer, which could potentially mean shifting old data out.
|
||||
|
||||
The new data should contain all the frames (in order) of any number of episodes. The indices should
|
||||
start from 0 (note to the developer: this can easily be generalized). See the `rollout` and
|
||||
`eval_policy` functions in `eval.py` for more information on how the data is constructed.
|
||||
|
||||
Shift the incoming data index and episode_index to continue on from the last frame. Note that this
|
||||
will be done in place!
|
||||
"""
|
||||
if len(missing_keys := (set(self.data_keys).difference(set(data)))) > 0:
|
||||
raise ValueError(f"Missing data keys: {missing_keys}")
|
||||
new_data_length = len(data[self.data_keys[0]])
|
||||
if not all(len(data[k]) == new_data_length for k in self.data_keys):
|
||||
raise ValueError("All data items should have the same length")
|
||||
|
||||
next_index = self._data[OnlineBuffer.NEXT_INDEX_KEY]
|
||||
|
||||
# Sanity check to make sure that the new data indices start from 0.
|
||||
assert data[OnlineBuffer.EPISODE_INDEX_KEY][0].item() == 0
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_samples > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
|
||||
|
||||
# Insert the new data starting from next_index. It may be necessary to wrap around to the start.
|
||||
n_surplus = max(0, new_data_length - (self._buffer_capacity - next_index))
|
||||
for k in self.data_keys:
|
||||
if n_surplus == 0:
|
||||
slc = slice(next_index, next_index + new_data_length)
|
||||
self._data[k][slc] = data[k]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][slc] = True
|
||||
else:
|
||||
self._data[k][next_index:] = data[k][:-n_surplus]
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY][next_index:] = True
|
||||
self._data[k][:n_surplus] = data[k][-n_surplus:]
|
||||
if n_surplus == 0:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = next_index + new_data_length
|
||||
else:
|
||||
self._data[OnlineBuffer.NEXT_INDEX_KEY] = n_surplus
|
||||
|
||||
@property
|
||||
def data_keys(self) -> list[str]:
|
||||
keys = set(self._data)
|
||||
keys.remove(OnlineBuffer.OCCUPANCY_MASK_KEY)
|
||||
keys.remove(OnlineBuffer.NEXT_INDEX_KEY)
|
||||
return sorted(keys)
|
||||
|
||||
@property
|
||||
def fps(self) -> float | None:
|
||||
return self._fps
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(
|
||||
np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
for k, v in item.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
item_[k] = v
|
||||
elif isinstance(v, np.ndarray):
|
||||
item_[k] = torch.from_numpy(v)
|
||||
else:
|
||||
item_[k] = torch.tensor(v)
|
||||
return item_
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self) or idx < -len(self):
|
||||
raise IndexError
|
||||
|
||||
item = {k: v[idx] for k, v in self._data.items() if not k.startswith("_")}
|
||||
|
||||
if self.delta_timestamps is None:
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
episode_index = item[OnlineBuffer.EPISODE_INDEX_KEY]
|
||||
current_ts = item[OnlineBuffer.TIMESTAMP_KEY]
|
||||
episode_data_indices = np.where(
|
||||
np.bitwise_and(
|
||||
self._data[OnlineBuffer.EPISODE_INDEX_KEY] == episode_index,
|
||||
self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
|
||||
)
|
||||
)[0]
|
||||
episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
|
||||
|
||||
for data_key in self.delta_timestamps:
|
||||
# Note: The logic in this loop is copied from `load_previous_and_future_frames`.
|
||||
# Get timestamps used as query to retrieve data of previous/future frames.
|
||||
query_ts = current_ts + self.delta_timestamps[data_key]
|
||||
|
||||
# Compute distances between each query timestamp and all timestamps of all the frames belonging to
|
||||
# the episode.
|
||||
dist = np.abs(query_ts[:, None] - episode_timestamps[None, :])
|
||||
argmin_ = np.argmin(dist, axis=1)
|
||||
min_ = dist[np.arange(dist.shape[0]), argmin_]
|
||||
|
||||
is_pad = min_ > self.tolerance_s
|
||||
|
||||
# Check violated query timestamps are all outside the episode range.
|
||||
assert (
|
||||
(query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
|
||||
).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
|
||||
") inside the episode range."
|
||||
)
|
||||
|
||||
# Load frames for this data key.
|
||||
item[data_key] = self._data[data_key][episode_data_indices[argmin_]]
|
||||
|
||||
item[f"{data_key}{OnlineBuffer.IS_PAD_POSTFIX}"] = is_pad
|
||||
|
||||
return self._item_to_tensors(item)
|
||||
|
||||
def get_data_by_key(self, key: str) -> torch.Tensor:
|
||||
"""Returns all data for a given data key as a Tensor."""
|
||||
return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
|
||||
|
||||
|
||||
def compute_sampler_weights(
|
||||
offline_dataset: LeRobotDataset,
|
||||
offline_drop_n_last_frames: int = 0,
|
||||
online_dataset: OnlineBuffer | None = None,
|
||||
online_sampling_ratio: float | None = None,
|
||||
online_drop_n_last_frames: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the sampling weights for the online training dataloader in train.py.
|
||||
|
||||
Args:
|
||||
offline_dataset: The LeRobotDataset used for offline pre-training.
|
||||
online_drop_n_last_frames: Number of frames to drop from the end of each offline dataset episode.
|
||||
online_dataset: The OnlineBuffer used in online training.
|
||||
online_sampling_ratio: The proportion of data that should be sampled from the online dataset. If an
|
||||
online dataset is provided, this value must also be provided.
|
||||
online_drop_n_first_frames: See `offline_drop_n_last_frames`. This is the same, but for the online
|
||||
dataset.
|
||||
Returns:
|
||||
Tensor of weights for [offline_dataset; online_dataset], normalized to 1.
|
||||
|
||||
Notes to maintainers:
|
||||
- This duplicates some logic from EpisodeAwareSampler. We should consider converging to one approach.
|
||||
- When used with `torch.utils.data.WeightedRandomSampler`, it could completely replace
|
||||
`EpisodeAwareSampler` as the online dataset related arguments are optional. The only missing feature
|
||||
is the ability to turn shuffling off.
|
||||
- Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
|
||||
included here to avoid adding complexity.
|
||||
"""
|
||||
if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
|
||||
raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
|
||||
if (online_dataset is None) ^ (online_sampling_ratio is None):
|
||||
raise ValueError(
|
||||
"`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
|
||||
)
|
||||
offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio
|
||||
|
||||
weights = []
|
||||
|
||||
if len(offline_dataset) > 0:
|
||||
offline_data_mask_indices = []
|
||||
for start_index, end_index in zip(
|
||||
offline_dataset.episode_data_index["from"],
|
||||
offline_dataset.episode_data_index["to"],
|
||||
strict=True,
|
||||
):
|
||||
offline_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
||||
)
|
||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(offline_dataset),),
|
||||
fill_value=offline_sampling_ratio / offline_data_mask.sum(),
|
||||
)
|
||||
* offline_data_mask
|
||||
)
|
||||
|
||||
if online_dataset is not None and len(online_dataset) > 0:
|
||||
online_data_mask_indices = []
|
||||
episode_indices = online_dataset.get_data_by_key("episode_index")
|
||||
for episode_idx in torch.unique(episode_indices):
|
||||
where_episode = torch.where(episode_indices == episode_idx)
|
||||
start_index = where_episode[0][0]
|
||||
end_index = where_episode[0][-1] + 1
|
||||
online_data_mask_indices.extend(
|
||||
range(start_index.item(), end_index.item() - online_drop_n_last_frames)
|
||||
)
|
||||
online_data_mask = torch.zeros(len(online_dataset), dtype=torch.bool)
|
||||
online_data_mask[torch.tensor(online_data_mask_indices)] = True
|
||||
weights.append(
|
||||
torch.full(
|
||||
size=(len(online_dataset),),
|
||||
fill_value=online_sampling_ratio / online_data_mask.sum(),
|
||||
)
|
||||
* online_data_mask
|
||||
)
|
||||
|
||||
weights = torch.cat(weights)
|
||||
|
||||
if weights.sum() == 0:
|
||||
weights += 1 / len(weights)
|
||||
else:
|
||||
weights /= weights.sum()
|
||||
|
||||
return weights
|
||||
@@ -1,56 +0,0 @@
|
||||
## Using / Updating `CODEBASE_VERSION` (for maintainers)
|
||||
|
||||
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
|
||||
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
|
||||
lerobot/common/datasets/lerobot_dataset.py) variable.
|
||||
|
||||
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
|
||||
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
|
||||
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
|
||||
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
|
||||
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
|
||||
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
|
||||
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
|
||||
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
|
||||
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
|
||||
|
||||
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
|
||||
`info.json` metadata.
|
||||
|
||||
### Uploading a new dataset
|
||||
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
|
||||
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
|
||||
dataset with the current `CODEBASE_VERSION`.
|
||||
|
||||
### Updating an existing dataset
|
||||
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
|
||||
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
|
||||
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
|
||||
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
|
||||
codebase won't be affected by your change and backward compatibility is maintained.
|
||||
|
||||
However, you will need to update the version of ALL the other datasets so that they have the new
|
||||
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
|
||||
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
|
||||
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
|
||||
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
|
||||
api = HfApi()
|
||||
|
||||
for repo_id in available_datasets:
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if CODEBASE_VERSION in branches:
|
||||
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
|
||||
continue
|
||||
else:
|
||||
# Now create a branch named after the new version by branching out from "main"
|
||||
# which is expected to be the preceding version
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
|
||||
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
|
||||
```
|
||||
@@ -14,121 +14,156 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains download scripts for raw datasets.
|
||||
|
||||
Example of usage:
|
||||
```
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py \
|
||||
--raw-dir data/lerobot-raw/pusht_raw \
|
||||
--repo-id lerobot-raw/pusht_raw
|
||||
```
|
||||
This file contains all obsolete download scripts. They are centralized here to not have to load
|
||||
useless dependencies when using datasets.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
import warnings
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
|
||||
# {raw_repo_id: raw_format}
|
||||
AVAILABLE_RAW_REPO_IDS = {
|
||||
"lerobot-raw/aloha_mobile_cabinet_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_chair_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_elevator_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_shrimp_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_wash_pan_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_mobile_wipe_wine_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_insertion_human_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_insertion_scripted_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_transfer_cube_human_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_sim_transfer_cube_scripted_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_battery_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_candy_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_coffee_new_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_coffee_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_cups_open_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_fork_pick_up_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_pingpong_test_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_pro_pencil_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_screw_driver_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_tape_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_thread_velcro_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_towel_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_left_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_vinh_cup_raw": "aloha_hdf5",
|
||||
"lerobot-raw/aloha_static_ziploc_slide_raw": "aloha_hdf5",
|
||||
"lerobot-raw/pusht_raw": "pusht_zarr",
|
||||
"lerobot-raw/umi_cup_in_the_wild_raw": "umi_zarr",
|
||||
"lerobot-raw/unitreeh1_fold_clothes_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_rearrange_objects_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_two_robot_greeting_raw": "aloha_hdf5",
|
||||
"lerobot-raw/unitreeh1_warehouse_raw": "aloha_hdf5",
|
||||
"lerobot-raw/xarm_lift_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_lift_medium_replay_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_raw": "xarm_pkl",
|
||||
"lerobot-raw/xarm_push_medium_replay_raw": "xarm_pkl",
|
||||
}
|
||||
def download_raw(raw_dir, dataset_id):
|
||||
if "aloha" in dataset_id or "image" in dataset_id:
|
||||
download_hub(raw_dir, dataset_id)
|
||||
elif "pusht" in dataset_id:
|
||||
download_pusht(raw_dir)
|
||||
elif "xarm" in dataset_id:
|
||||
download_xarm(raw_dir)
|
||||
elif "umi" in dataset_id:
|
||||
download_umi(raw_dir)
|
||||
else:
|
||||
raise ValueError(dataset_id)
|
||||
|
||||
|
||||
def download_raw(raw_dir: Path, repo_id: str):
|
||||
check_repo_id(repo_id)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
import zipfile
|
||||
|
||||
if not dataset_id.endswith("_raw"):
|
||||
warnings.warn(
|
||||
f"""`dataset_id` ({dataset_id}) doesn't end with '_raw' (e.g. 'lerobot/pusht_raw'). Following this
|
||||
naming convention by renaming your repository is advised, but not mandatory.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
import requests
|
||||
|
||||
# Send warning if raw_dir isn't well formated
|
||||
if raw_dir.parts[-2] != user_id or raw_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"""`raw_dir` ({raw_dir}) doesn't contain a community or user id `/` the name of the dataset that
|
||||
match the `repo_id` (e.g. 'data/lerobot/pusht_raw'). Following this naming convention is advised,
|
||||
but not mandatory.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
print(f"downloading from {url}")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)
|
||||
|
||||
zip_file = io.BytesIO()
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
zip_file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
zip_file.seek(0)
|
||||
|
||||
with zipfile.ZipFile(zip_file, "r") as zip_ref:
|
||||
zip_ref.extractall(destination_folder)
|
||||
|
||||
|
||||
def download_pusht(raw_dir: str):
|
||||
pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(pusht_url, raw_dir)
|
||||
# file is created inside a useful "pusht" directory, so we move it out and delete the dir
|
||||
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
|
||||
shutil.move(raw_dir / "pusht" / "pusht_cchi_v7_replay.zarr", zarr_path)
|
||||
shutil.rmtree(raw_dir / "pusht")
|
||||
|
||||
|
||||
def download_xarm(raw_dir: Path):
|
||||
"""Download all xarm datasets at once"""
|
||||
import zipfile
|
||||
|
||||
import gdown
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for pkl_path in zip_f.namelist():
|
||||
if pkl_path.startswith("data/xarm") and pkl_path.endswith(".pkl"):
|
||||
zip_f.extract(member=pkl_path)
|
||||
# move to corresponding raw directory
|
||||
extract_dir = pkl_path.replace("/buffer.pkl", "")
|
||||
raw_pkl_path = raw_dir / "buffer.pkl"
|
||||
shutil.move(pkl_path, raw_pkl_path)
|
||||
shutil.rmtree(extract_dir)
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
def download_hub(raw_dir: Path, dataset_id: str):
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
logging.info(f"Start downloading from huggingface.co/cadene for {dataset_id}")
|
||||
snapshot_download(f"cadene/{dataset_id}_raw", repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/cadene for {dataset_id}")
|
||||
|
||||
|
||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||
if data_dir is None:
|
||||
data_dir = Path("data")
|
||||
for repo_id in AVAILABLE_RAW_REPO_IDS:
|
||||
raw_dir = data_dir / repo_id
|
||||
download_raw(raw_dir, repo_id)
|
||||
def download_umi(raw_dir: Path):
|
||||
url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=f"""A script to download raw datasets from Hugging Face hub to a local directory. Here is a
|
||||
non exhaustive list of available repositories to use in `--repo-id`: {AVAILABLE_RAW_REPO_IDS}""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""Repositery identifier on Hugging Face: a community or a user name `/` the name of
|
||||
the dataset (e.g. `lerobot/pusht_raw`, `cadene/aloha_sim_insertion_human_raw`).""",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
download_raw(**vars(args))
|
||||
raw_dir = Path(raw_dir)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
download_and_extract_zip(url_cup_in_the_wild, zarr_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
data_dir = Path("data")
|
||||
dataset_ids = [
|
||||
"pusht_image",
|
||||
"xarm_lift_medium_image",
|
||||
"xarm_lift_medium_replay_image",
|
||||
"xarm_push_medium_image",
|
||||
"xarm_push_medium_replay_image",
|
||||
"aloha_sim_insertion_human_image",
|
||||
"aloha_sim_insertion_scripted_image",
|
||||
"aloha_sim_transfer_cube_human_image",
|
||||
"aloha_sim_transfer_cube_scripted_image",
|
||||
"pusht",
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"aloha_mobile_cabinet",
|
||||
"aloha_mobile_chair",
|
||||
"aloha_mobile_elevator",
|
||||
"aloha_mobile_shrimp",
|
||||
"aloha_mobile_wash_pan",
|
||||
"aloha_mobile_wipe_wine",
|
||||
"aloha_static_battery",
|
||||
"aloha_static_candy",
|
||||
"aloha_static_coffee",
|
||||
"aloha_static_coffee_new",
|
||||
"aloha_static_cups_open",
|
||||
"aloha_static_fork_pick_up",
|
||||
"aloha_static_pingpong_test",
|
||||
"aloha_static_pro_pencil",
|
||||
"aloha_static_screw_driver",
|
||||
"aloha_static_tape",
|
||||
"aloha_static_thread_velcro",
|
||||
"aloha_static_towel",
|
||||
"aloha_static_vinh_cup",
|
||||
"aloha_static_vinh_cup_left",
|
||||
"aloha_static_ziploc_slide",
|
||||
"umi_cup_in_the_wild",
|
||||
]
|
||||
for dataset_id in dataset_ids:
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Use this script to batch encode lerobot dataset from their raw format to LeRobotDataset and push their updated
|
||||
version to the hub. Under the hood, this script reuses 'push_dataset_to_hub.py'. It assumes that you already
|
||||
downloaded raw datasets, which you can do with the related '_download_raw.py' script.
|
||||
|
||||
For instance, for codebase_version = 'v1.6', the following command was run, assuming raw datasets from
|
||||
lerobot-raw were downloaded in 'raw/datasets/directory':
|
||||
```bash
|
||||
python lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py \
|
||||
--raw-dir raw/datasets/directory \
|
||||
--raw-repo-ids lerobot-raw \
|
||||
--local-dir push/datasets/directory \
|
||||
--tests-data-dir tests/data \
|
||||
--push-repo lerobot \
|
||||
--vcodec libsvtav1 \
|
||||
--pix-fmt yuv420p \
|
||||
--g 2 \
|
||||
--crf 30
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
|
||||
|
||||
|
||||
def get_push_repo_id_from_raw(raw_repo_id: str, push_repo: str) -> str:
|
||||
dataset_id_raw = raw_repo_id.split("/")[1]
|
||||
dataset_id = dataset_id_raw.removesuffix("_raw")
|
||||
return f"{push_repo}/{dataset_id}"
|
||||
|
||||
|
||||
def encode_datasets(
|
||||
raw_dir: Path,
|
||||
raw_repo_ids: list[str],
|
||||
push_repo: str,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
local_dir: Path | None = None,
|
||||
tests_data_dir: Path | None = None,
|
||||
raw_format: str | None = None,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if len(raw_repo_ids) == 1 and raw_repo_ids[0].lower() == "lerobot-raw":
|
||||
raw_repo_ids_format = AVAILABLE_RAW_REPO_IDS
|
||||
else:
|
||||
if raw_format is None:
|
||||
raise ValueError(raw_format)
|
||||
raw_repo_ids_format = {id_: raw_format for id_ in raw_repo_ids}
|
||||
|
||||
for raw_repo_id, repo_raw_format in raw_repo_ids_format.items():
|
||||
check_repo_id(raw_repo_id)
|
||||
dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo)
|
||||
dataset_raw_dir = raw_dir / raw_repo_id
|
||||
dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None
|
||||
encoding = {
|
||||
"vcodec": vcodec,
|
||||
"pix_fmt": pix_fmt,
|
||||
"g": g,
|
||||
"crf": crf,
|
||||
}
|
||||
|
||||
if not (dataset_raw_dir).is_dir():
|
||||
raise NotADirectoryError(dataset_raw_dir)
|
||||
|
||||
if not dry_run:
|
||||
push_dataset_to_hub(
|
||||
dataset_raw_dir,
|
||||
raw_format=repo_raw_format,
|
||||
repo_id=dataset_repo_id_push,
|
||||
local_dir=dataset_dir,
|
||||
resume=True,
|
||||
encoding=encoding,
|
||||
tests_data_dir=tests_data_dir,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"DRY RUN: {dataset_raw_dir} --> {dataset_dir} --> {dataset_repo_id_push}@{CODEBASE_VERSION}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
default=Path("data"),
|
||||
help="Directory where raw datasets are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw-repo-ids",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["lerobot-raw"],
|
||||
help="""Raw dataset repo ids. if 'lerobot-raw', the keys from `AVAILABLE_RAW_REPO_IDS` will be
|
||||
used and raw datasets will be fetched from the 'lerobot-raw/' repo and pushed with their
|
||||
associated format. It is assumed that each dataset is located at `raw_dir / raw_repo_id` """,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""Raw format to use for the raw repo-ids. Must be specified if --raw-repo-ids is not
|
||||
'lerobot-raw'""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="""When provided, writes the dataset converted to LeRobotDataset format in this directory
|
||||
(e.g. `data/lerobot/aloha_mobile_chair`).""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-repo",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="Repo to upload datasets to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
default="libsvtav1",
|
||||
help="Codec to use for encoding videos",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
default="yuv420p",
|
||||
help="Pixel formats (chroma subsampling) to be used for encoding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Group of pictures sizes to be used for encoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Constant rate factors to be used for encoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"When provided, save tests artifacts into the given directory "
|
||||
"(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
type=int,
|
||||
default=0,
|
||||
help="If not set to 0, this script won't download or upload anything.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
encode_datasets(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,20 +17,19 @@
|
||||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
@@ -42,7 +41,7 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
@@ -123,7 +122,8 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
videos_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir = out_dir / "videos"
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
@@ -156,7 +156,16 @@ def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episod
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
return data_dict
|
||||
# Get the episode index containing for each unique episode index
|
||||
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
|
||||
from_ = first_ep_index_df["start_index"].tolist()
|
||||
to_ = from_[1:] + [len(df)]
|
||||
episode_data_index = {
|
||||
"from": from_,
|
||||
"to": to_,
|
||||
}
|
||||
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
@@ -194,14 +203,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
init_logging()
|
||||
|
||||
if debug:
|
||||
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
@@ -213,21 +220,11 @@ def from_raw_to_lerobot_format(
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
if encoding is not None:
|
||||
warnings.warn(
|
||||
"Video encoding is currently done outside of LeRobot for the dora_parquet format.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
data_df = load_from_raw(raw_dir, videos_dir, fps, episodes)
|
||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = "unknown"
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -28,14 +28,8 @@ import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -49,9 +43,6 @@ def get_cameras(hdf5_data):
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
# only frames from simulation are uncompressed
|
||||
compressed_images = "sim" not in raw_dir.name
|
||||
|
||||
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
|
||||
assert len(hdf5_paths) != 0
|
||||
for hdf5_path in hdf5_paths:
|
||||
@@ -68,32 +59,22 @@ def check_format(raw_dir) -> bool:
|
||||
for camera in get_cameras(data):
|
||||
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
|
||||
|
||||
if compressed_images:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
# ndim 2 when image are compressed and 4 when uncompressed
|
||||
assert data[f"/observations/images/{camera}"].ndim in [2, 4]
|
||||
if data[f"/observations/images/{camera}"].ndim == 4:
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
# only frames from simulation are uncompressed
|
||||
compressed_images = "sim" not in raw_dir.name
|
||||
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
num_episodes = len(hdf5_files)
|
||||
|
||||
hdf5_files = list(raw_dir.glob("*.hdf5"))
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx in tqdm.tqdm(ep_ids):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for ep_idx, ep_path in tqdm.tqdm(enumerate(hdf5_files), total=len(hdf5_files)):
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
@@ -113,7 +94,7 @@ def load_from_raw(
|
||||
for camera in get_cameras(ep):
|
||||
img_key = f"observation.images.{camera}"
|
||||
|
||||
if compressed_images:
|
||||
if ep[f"/observations/images/{camera}"].ndim == 2:
|
||||
import cv2
|
||||
|
||||
# load one compressed image after the other in RAM and uncompress
|
||||
@@ -128,13 +109,13 @@ def load_from_raw(
|
||||
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
video_path = out_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
@@ -161,13 +142,19 @@ def load_from_raw(
|
||||
assert isinstance(ep_idx, int)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
gc.collect()
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
@@ -205,29 +192,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 50
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset = to_hf_dataset(data_dir, video)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format of png images files recorded with capture_camera_feed.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def check_format(raw_dir: Path) -> bool:
|
||||
image_paths = list(raw_dir.glob("frame_*.png"))
|
||||
if len(image_paths) == 0:
|
||||
raise ValueError
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, fps: int, episodes: list[int] | None = None):
|
||||
if episodes is not None:
|
||||
# TODO(aliberts): add support for multi-episodes.
|
||||
raise NotImplementedError()
|
||||
|
||||
ep_dict = {}
|
||||
ep_idx = 0
|
||||
|
||||
image_paths = sorted(raw_dir.glob("frame_*.png"))
|
||||
num_frames = len(image_paths)
|
||||
|
||||
ep_dict["observation.image"] = [PILImage.open(x) for x in image_paths]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
|
||||
ep_dicts = [ep_dict]
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
if video:
|
||||
features["observation.image"] = VideoFrame()
|
||||
else:
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
if video or episodes or encoding is not None:
|
||||
# TODO(aliberts): support this
|
||||
raise NotImplementedError
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
@@ -25,14 +25,8 @@ import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -59,15 +53,7 @@ def check_format(raw_dir):
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
keypoints_instead_of_image: bool = False,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
@@ -85,6 +71,7 @@ def load_from_raw(
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
|
||||
episode_ids = torch.from_numpy(zarr_data.get_episode_idxs())
|
||||
num_episodes = zarr_data.meta["episode_ends"].shape[0]
|
||||
assert len(
|
||||
{zarr_data[key].shape[0] for key in zarr_data.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
@@ -97,44 +84,32 @@ def load_from_raw(
|
||||
states = torch.from_numpy(zarr_data["state"])
|
||||
actions = torch.from_numpy(zarr_data["action"])
|
||||
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx = 0
|
||||
for to_idx in zarr_data.meta["episode_ends"]:
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
num_episodes = len(from_ids)
|
||||
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
id_from = 0
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = zarr_data.meta["episode_ends"][ep_idx]
|
||||
num_frames = id_to - id_from
|
||||
|
||||
# sanity check
|
||||
assert (episode_ids[from_idx:to_idx] == ep_idx).all()
|
||||
assert (episode_ids[id_from:id_to] == ep_idx).all()
|
||||
|
||||
# get image
|
||||
if not keypoints_instead_of_image:
|
||||
image = imgs[from_idx:to_idx]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
image = imgs[id_from:id_to]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
# get state
|
||||
state = states[from_idx:to_idx]
|
||||
state = states[id_from:id_to]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
|
||||
# get reward, success, done, and (maybe) keypoints
|
||||
# get reward, success, done
|
||||
reward = torch.zeros(num_frames)
|
||||
success = torch.zeros(num_frames, dtype=torch.bool)
|
||||
if keypoints_instead_of_image:
|
||||
keypoints = torch.zeros(num_frames, 16) # 8 keypoints each with 2 coords
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
@@ -150,7 +125,7 @@ def load_from_raw(
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
block_body = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
@@ -158,41 +133,34 @@ def load_from_raw(
|
||||
coverage = intersection_area / goal_area
|
||||
reward[i] = np.clip(coverage / success_threshold, 0, 1)
|
||||
success[i] = coverage > success_threshold
|
||||
if keypoints_instead_of_image:
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
if not keypoints_instead_of_image:
|
||||
imgs_array = [x.numpy() for x in image]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
imgs_array = [x.numpy() for x in image]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = agent_pos
|
||||
if keypoints_instead_of_image:
|
||||
ep_dict["observation.environment_state"] = keypoints
|
||||
ep_dict["action"] = actions[from_idx:to_idx]
|
||||
ep_dict["action"] = actions[id_from:id_to]
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
@@ -203,30 +171,31 @@ def load_from_raw(
|
||||
ep_dict["next.done"] = torch.cat([done[1:], done[[-1]]])
|
||||
ep_dict["next.success"] = torch.cat([success[1:], success[[-1]]])
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
def to_hf_dataset(data_dict, video):
|
||||
features = {}
|
||||
|
||||
if not keypoints_instead_of_image:
|
||||
if video:
|
||||
features["observation.image"] = VideoFrame()
|
||||
else:
|
||||
features["observation.image"] = Image()
|
||||
if video:
|
||||
features["observation.image"] = VideoFrame()
|
||||
else:
|
||||
features["observation.image"] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if keypoints_instead_of_image:
|
||||
features["observation.environment_state"] = Sequence(
|
||||
length=data_dict["observation.environment_state"].shape[1],
|
||||
feature=Value(dtype="float32", id=None),
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
@@ -243,33 +212,18 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False):
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
# Manually change this to True to use keypoints of the T instead of an image observation (but don't merge
|
||||
# with True). Also make sure to use video = 0 in the `push_dataset_to_hub.py` script.
|
||||
keypoints_instead_of_image = False
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 10
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding)
|
||||
hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video if not keypoints_instead_of_image else 0,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -19,21 +19,16 @@ import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import zarr
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -64,14 +59,23 @@ def check_format(raw_dir) -> bool:
|
||||
assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def get_episode_idxs(episode_ends: np.ndarray) -> np.ndarray:
|
||||
# Optimized and simplified version of this function: https://github.com/real-stanford/universal_manipulation_interface/blob/298776ce251f33b6b3185a98d6e7d1f9ad49168b/diffusion_policy/common/replay_buffer.py#L374
|
||||
from numba import jit
|
||||
|
||||
@jit(nopython=True)
|
||||
def _get_episode_idxs(episode_ends):
|
||||
result = np.zeros((episode_ends[-1],), dtype=np.int64)
|
||||
start_idx = 0
|
||||
for episode_number, end_idx in enumerate(episode_ends):
|
||||
result[start_idx:end_idx] = episode_number
|
||||
start_idx = end_idx
|
||||
return result
|
||||
|
||||
return _get_episode_idxs(episode_ends)
|
||||
|
||||
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
zarr_path = raw_dir / "cup_in_the_wild.zarr"
|
||||
zarr_data = zarr.open(zarr_path, mode="r")
|
||||
|
||||
@@ -88,79 +92,74 @@ def load_from_raw(
|
||||
episode_ends = zarr_data["meta/episode_ends"][:]
|
||||
num_episodes = episode_ends.shape[0]
|
||||
|
||||
episode_ids = torch.from_numpy(get_episode_idxs(episode_ends))
|
||||
|
||||
# We convert it in torch tensor later because the jit function does not support torch tensors
|
||||
episode_ends = torch.from_numpy(episode_ends)
|
||||
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx = 0
|
||||
for to_idx in episode_ends:
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
ep_dicts_dir = videos_dir / "ep_dicts"
|
||||
ep_dicts_dir.mkdir(exist_ok=True, parents=True)
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
ep_dict_path = ep_dicts_dir / f"{ep_idx}"
|
||||
if not ep_dict_path.is_file():
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
id_from = 0
|
||||
for ep_idx in tqdm.tqdm(range(num_episodes)):
|
||||
id_to = episode_ends[ep_idx]
|
||||
num_frames = id_to - id_from
|
||||
|
||||
# TODO(rcadene): save temporary images of the episode?
|
||||
# sanity heck
|
||||
assert (episode_ids[id_from:id_to] == ep_idx).all()
|
||||
|
||||
state = states[from_idx:to_idx]
|
||||
# TODO(rcadene): save temporary images of the episode?
|
||||
|
||||
ep_dict = {}
|
||||
state = states[id_from:id_to]
|
||||
|
||||
# load 57MB of images in RAM (400x224x224x3 uint8)
|
||||
imgs_array = zarr_data["data/camera0_rgb"][from_idx:to_idx]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
if not video_path.is_file():
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
ep_dict = {}
|
||||
|
||||
# encode images to a mp4 video
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
# load 57MB of images in RAM (400x224x224x3 uint8)
|
||||
imgs_array = zarr_data["data/camera0_rgb"][id_from:id_to]
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = out_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [
|
||||
{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
|
||||
]
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[from_idx:to_idx]
|
||||
ep_dict["start_pos"] = start_pos[from_idx:to_idx]
|
||||
ep_dict["gripper_width"] = gripper_width[from_idx:to_idx]
|
||||
torch.save(ep_dict, ep_dict_path)
|
||||
# store the reference to the video frame
|
||||
ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]
|
||||
else:
|
||||
ep_dict = torch.load(ep_dict_path)
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
ep_dict["observation.state"] = state
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_data_index_from"] = torch.tensor([id_from] * num_frames)
|
||||
ep_dict["episode_data_index_to"] = torch.tensor([id_from + num_frames] * num_frames)
|
||||
ep_dict["end_pose"] = end_pose[id_from:id_to]
|
||||
ep_dict["start_pos"] = start_pos[id_from:id_to]
|
||||
ep_dict["gripper_width"] = gripper_width[id_from:id_to]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
id_from += num_frames
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
@@ -200,14 +199,7 @@ def to_hf_dataset(data_dict, video):
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
@@ -220,15 +212,11 @@ def from_raw_to_lerobot_format(
|
||||
"Generating UMI dataset without `video=True` creates ~150GB on disk and requires ~80GB in RAM."
|
||||
)
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,22 +20,25 @@ import numpy
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
|
||||
|
||||
def concatenate_episodes(ep_dicts):
|
||||
def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False):
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
if drop_episodes_last_frame:
|
||||
data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts])
|
||||
else:
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
if drop_episodes_last_frame:
|
||||
data_dict[key].pop()
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
@@ -54,21 +56,3 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
|
||||
signature = inspect.signature(encode_video_frames)
|
||||
return {
|
||||
k: v.default
|
||||
for k, v in signature.parameters.items()
|
||||
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
|
||||
}
|
||||
|
||||
|
||||
def check_repo_id(repo_id: str) -> None:
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
@@ -25,14 +25,8 @@ import tqdm
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
@@ -60,49 +54,37 @@ def check_format(raw_dir):
|
||||
assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict)
|
||||
|
||||
|
||||
def load_from_raw(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int,
|
||||
video: bool,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
||||
pkl_path = raw_dir / "buffer.pkl"
|
||||
|
||||
with open(pkl_path, "rb") as f:
|
||||
pkl_data = pickle.load(f)
|
||||
|
||||
# load data indices from which each episode starts and ends
|
||||
from_ids, to_ids = [], []
|
||||
from_idx, to_idx = 0, 0
|
||||
for done in pkl_data["dones"]:
|
||||
to_idx += 1
|
||||
if not done:
|
||||
continue
|
||||
from_ids.append(from_idx)
|
||||
to_ids.append(to_idx)
|
||||
from_idx = to_idx
|
||||
|
||||
num_episodes = len(from_ids)
|
||||
|
||||
ep_dicts = []
|
||||
ep_ids = episodes if episodes else range(num_episodes)
|
||||
for ep_idx, selected_ep_idx in tqdm.tqdm(enumerate(ep_ids)):
|
||||
from_idx = from_ids[selected_ep_idx]
|
||||
to_idx = to_ids[selected_ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
image = torch.tensor(pkl_data["observations"]["rgb"][from_idx:to_idx])
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
ep_idx = 0
|
||||
total_frames = pkl_data["actions"].shape[0]
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
id_to += 1
|
||||
|
||||
if not pkl_data["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = id_to - id_from
|
||||
|
||||
image = torch.tensor(pkl_data["observations"]["rgb"][id_from:id_to])
|
||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
||||
state = torch.tensor(pkl_data["observations"]["state"][from_idx:to_idx])
|
||||
action = torch.tensor(pkl_data["actions"][from_idx:to_idx])
|
||||
state = torch.tensor(pkl_data["observations"]["state"][id_from:id_to])
|
||||
action = torch.tensor(pkl_data["actions"][id_from:id_to])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][from_idx:to_idx])
|
||||
# next_state = torch.tensor(pkl_data["next_observations"]["state"][from_idx:to_idx])
|
||||
next_reward = torch.tensor(pkl_data["rewards"][from_idx:to_idx])
|
||||
next_done = torch.tensor(pkl_data["dones"][from_idx:to_idx])
|
||||
# next_image = torch.tensor(pkl_data["next_observations"]["rgb"][id_from:id_to])
|
||||
# next_state = torch.tensor(pkl_data["next_observations"]["state"][id_from:id_to])
|
||||
next_reward = torch.tensor(pkl_data["rewards"][id_from:id_to])
|
||||
next_done = torch.tensor(pkl_data["dones"][id_from:id_to])
|
||||
|
||||
ep_dict = {}
|
||||
|
||||
@@ -110,13 +92,13 @@ def load_from_raw(
|
||||
img_key = "observation.image"
|
||||
if video:
|
||||
# save png images in temporary directory
|
||||
tmp_imgs_dir = videos_dir / "tmp_images"
|
||||
tmp_imgs_dir = out_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
|
||||
# encode images to a mp4 video
|
||||
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
|
||||
video_path = videos_dir / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
|
||||
video_path = out_dir / "videos" / fname
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
@@ -137,11 +119,18 @@ def load_from_raw(
|
||||
ep_dict["next.done"] = next_done
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
return data_dict
|
||||
id_from = id_to
|
||||
ep_idx += 1
|
||||
|
||||
# process first episode only
|
||||
if debug:
|
||||
break
|
||||
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video):
|
||||
@@ -172,29 +161,18 @@ def to_hf_dataset(data_dict, video):
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(
|
||||
raw_dir: Path,
|
||||
videos_dir: Path,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
):
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 15
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
data_dict, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
return hf_dataset, episode_data_index, info
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import Transform
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
|
||||
class RandomSubsetApply(Transform):
|
||||
"""Apply a random subset of N transformations from a list of transformations.
|
||||
|
||||
Args:
|
||||
transforms: list of transformations.
|
||||
p: represents the multinomial probabilities (with no replacement) used for sampling the transform.
|
||||
If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms
|
||||
have the same probability.
|
||||
n_subset: number of transformations to apply. If ``None``, all transforms are applied.
|
||||
Must be in [1, len(transforms)].
|
||||
random_order: apply transformations in a random order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transforms: Sequence[Callable],
|
||||
p: list[float] | None = None,
|
||||
n_subset: int | None = None,
|
||||
random_order: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(transforms, Sequence):
|
||||
raise TypeError("Argument transforms should be a sequence of callables")
|
||||
if p is None:
|
||||
p = [1] * len(transforms)
|
||||
elif len(p) != len(transforms):
|
||||
raise ValueError(
|
||||
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
|
||||
)
|
||||
|
||||
if n_subset is None:
|
||||
n_subset = len(transforms)
|
||||
elif not isinstance(n_subset, int):
|
||||
raise TypeError("n_subset should be an int or None")
|
||||
elif not (1 <= n_subset <= len(transforms)):
|
||||
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
|
||||
|
||||
self.transforms = transforms
|
||||
total = sum(p)
|
||||
self.p = [prob / total for prob in p]
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
return outputs
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
f"transforms={self.transforms}, "
|
||||
f"p={self.p}, "
|
||||
f"n_subset={self.n_subset}, "
|
||||
f"random_order={self.random_order}"
|
||||
)
|
||||
|
||||
|
||||
class SharpnessJitter(Transform):
|
||||
"""Randomly change the sharpness of an image or video.
|
||||
|
||||
Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly.
|
||||
While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image,
|
||||
SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of
|
||||
augmentations as a result.
|
||||
|
||||
A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness
|
||||
by a factor of 2.
|
||||
|
||||
If the input is a :class:`torch.Tensor`,
|
||||
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
|
||||
Args:
|
||||
sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from
|
||||
[max(0, 1 - sharpness), 1 + sharpness] or the given
|
||||
[min, max]. Should be non negative numbers.
|
||||
"""
|
||||
|
||||
def __init__(self, sharpness: float | Sequence[float]) -> None:
|
||||
super().__init__()
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
sharpness[0] = max(sharpness[0], 0.0)
|
||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||
sharpness = [float(v) for v in sharpness]
|
||||
else:
|
||||
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
|
||||
|
||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
|
||||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
@@ -15,27 +15,17 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
@@ -90,28 +80,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
@cache
|
||||
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
warnings.warn(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
return "main"
|
||||
else:
|
||||
return version
|
||||
|
||||
|
||||
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
|
||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
@@ -132,9 +101,7 @@ def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datas
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
|
||||
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -152,9 +119,8 @@ def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
@@ -171,10 +137,7 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
@@ -191,8 +154,7 @@ def load_info(repo_id, version, root) -> dict:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
||||
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
@@ -204,8 +166,7 @@ def load_videos(repo_id, version, root) -> Path:
|
||||
path = Path(root) / repo_id / "videos"
|
||||
else:
|
||||
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version)
|
||||
path = Path(repo_dir) / "videos"
|
||||
|
||||
return path
|
||||
@@ -393,29 +354,3 @@ def cycle(iterable):
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
api = HfApi()
|
||||
|
||||
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
|
||||
refs = [branch.ref for branch in branches]
|
||||
ref = f"refs/heads/{branch}"
|
||||
if ref in refs:
|
||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
card.text += text
|
||||
return card
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
@@ -28,11 +27,7 @@ from datasets.features.features import register_feature
|
||||
|
||||
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor],
|
||||
video_frame_keys: list[str],
|
||||
videos_dir: Path,
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
|
||||
):
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
||||
@@ -51,14 +46,14 @@ def load_from_videos(
|
||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
@@ -68,22 +63,11 @@ def decode_video_frames_torchvision(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
@@ -94,9 +78,21 @@ def decode_video_frames_torchvision(
|
||||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
if device == "cpu":
|
||||
# explicitely use pyav
|
||||
torchvision.set_video_backend("pyav")
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
elif device == "cuda":
|
||||
# TODO(rcadene, aliberts): implement video decoding with GPU
|
||||
# torchvision.set_video_backend("cuda")
|
||||
# torchvision.set_video_backend("video_reader")
|
||||
# requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
# check possible bug: https://github.com/pytorch/vision/issues/7745
|
||||
raise NotImplementedError(
|
||||
"Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
|
||||
)
|
||||
else:
|
||||
raise ValueError(device)
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
@@ -124,9 +120,7 @@ def decode_video_frames_torchvision(
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
|
||||
reader.container.close()
|
||||
reader = None
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
@@ -142,10 +136,6 @@ def decode_video_frames_torchvision(
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
@@ -162,59 +152,22 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: str | None = "error",
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
|
||||
"""More info on ffmpeg arguments tuning on `lerobot/common/datasets/_video_benchmark/README.md`"""
|
||||
video_path = Path(video_path)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_args = OrderedDict(
|
||||
[
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame_%06d.png")),
|
||||
("-vcodec", vcodec),
|
||||
("-pix_fmt", pix_fmt),
|
||||
]
|
||||
ffmpeg_cmd = (
|
||||
f"ffmpeg -r {fps} "
|
||||
"-f image2 "
|
||||
"-loglevel error "
|
||||
f"-i {str(imgs_dir / 'frame_%06d.png')} "
|
||||
"-vcodec libx264 "
|
||||
"-g 2 "
|
||||
"-pix_fmt yuv444p "
|
||||
f"{str(video_path)}"
|
||||
)
|
||||
|
||||
if g is not None:
|
||||
ffmpeg_args["-g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
ffmpeg_args["-crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
ffmpeg_args[key] = value
|
||||
|
||||
if log_level is not None:
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not video_path.exists():
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {video_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
subprocess.run(ffmpeg_cmd.split(" "), check=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -19,7 +19,7 @@ import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
@@ -27,9 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
|
||||
@@ -28,35 +28,33 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
if "pixels" in observations and isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
else:
|
||||
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
return_observations[imgkey] = img
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
observations["environment_state"]
|
||||
).float()
|
||||
return_observations[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
|
||||
return return_observations
|
||||
|
||||
@@ -241,6 +241,5 @@ class Logger:
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
assert mode in {"train", "eval"}
|
||||
assert self._wandb is not None
|
||||
wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
@@ -26,10 +26,7 @@ class ACTConfig:
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
||||
@@ -76,10 +73,12 @@ class ACTConfig:
|
||||
documentation in the policy class).
|
||||
latent_dim: The VAE's latent dimension.
|
||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
|
||||
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
|
||||
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
|
||||
more information on how ensembling works, please see `ACTTemporalEnsembler`.
|
||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||||
dropout: Dropout to use in the transformer layers (see code for details).
|
||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
@@ -130,15 +129,16 @@ class ACTConfig:
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
# As a consequence we also remove the final, unused layer normalization, by default
|
||||
n_decoder_layers: int = 1
|
||||
decoder_norm: bool = False
|
||||
# VAE.
|
||||
use_vae: bool = True
|
||||
latent_dim: int = 32
|
||||
n_vae_encoder_layers: int = 4
|
||||
|
||||
# Inference.
|
||||
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
|
||||
temporal_ensemble_coeff: float | None = None
|
||||
temporal_ensemble_momentum: float | None = None
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: float = 0.1
|
||||
@@ -150,7 +150,7 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
|
||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||
"because the policy needs to be queried every step to compute the ensembled action."
|
||||
@@ -164,8 +164,3 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
)
|
||||
if (
|
||||
not any(k.startswith("observation.image") for k in self.input_shapes)
|
||||
and "observation.environment_state" not in self.input_shapes
|
||||
):
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
@@ -38,13 +38,7 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
@@ -83,15 +77,12 @@ class ACTPolicy(
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
self.temporal_ensembler.reset()
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
self._ensembled_actions = None
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@@ -106,16 +97,26 @@ class ACTPolicy(
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
action = self.temporal_ensembler.update(actions)
|
||||
if self._ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
self._ensembled_actions = actions.clone()
|
||||
else:
|
||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the EMA update for those entries.
|
||||
alpha = self.config.temporal_ensemble_momentum
|
||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
||||
# "Consume" the first action.
|
||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
||||
return action
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
@@ -134,9 +135,7 @@ class ACTPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -161,97 +160,6 @@ class ACTPolicy(
|
||||
return loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
|
||||
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
|
||||
|
||||
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
|
||||
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
|
||||
coefficient works:
|
||||
- Setting it to 0 uniformly weighs all actions.
|
||||
- Setting it positive gives more weight to older actions.
|
||||
- Setting it negative gives more weight to newer actions.
|
||||
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
|
||||
results in older actions being weighed more highly than newer actions (the experiments documented in
|
||||
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
|
||||
detrimental: doing so aggressively may diminish the benefits of action chunking).
|
||||
|
||||
Here we use an online method for computing the average rather than caching a history of actions in
|
||||
order to compute the average offline. For a simple 1D sequence it looks something like:
|
||||
|
||||
```
|
||||
import torch
|
||||
|
||||
seq = torch.linspace(8, 8.5, 100)
|
||||
print(seq)
|
||||
|
||||
m = 0.01
|
||||
exp_weights = torch.exp(-m * torch.arange(len(seq)))
|
||||
print(exp_weights)
|
||||
|
||||
# Calculate offline
|
||||
avg = (exp_weights * seq).sum() / exp_weights.sum()
|
||||
print("offline", avg)
|
||||
|
||||
# Calculate online
|
||||
for i, item in enumerate(seq):
|
||||
if i == 0:
|
||||
avg = item
|
||||
continue
|
||||
avg *= exp_weights[:i].sum()
|
||||
avg += item * exp_weights[i]
|
||||
avg /= exp_weights[:i+1].sum()
|
||||
print("online", avg)
|
||||
```
|
||||
"""
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the online computation variables."""
|
||||
self.ensembled_actions = None
|
||||
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
|
||||
self.ensembled_actions_count = None
|
||||
|
||||
def update(self, actions: Tensor) -> Tensor:
|
||||
"""
|
||||
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
|
||||
time steps, and pop/return the next batch of actions in the sequence.
|
||||
"""
|
||||
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
|
||||
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
|
||||
if self.ensembled_actions is None:
|
||||
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
self.ensembled_actions = actions.clone()
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the online update for those entries.
|
||||
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
|
||||
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
|
||||
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
|
||||
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
self.ensembled_actions[:, 0],
|
||||
self.ensembled_actions[:, 1:],
|
||||
self.ensembled_actions_count[1:],
|
||||
)
|
||||
return action
|
||||
|
||||
|
||||
class ACT(nn.Module):
|
||||
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||
|
||||
@@ -292,14 +200,12 @@ class ACT(nn.Module):
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_robot_state = "observation.state" in config.input_shapes
|
||||
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
self.use_input_state = "observation.state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
if self.use_robot_state:
|
||||
if self.use_input_state:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
@@ -312,7 +218,7 @@ class ACT(nn.Module):
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_robot_state:
|
||||
if self.use_input_state:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
@@ -320,45 +226,34 @@ class ACT(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
if self.use_images:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final feature
|
||||
# map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
self.decoder = ACTDecoder(config)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
if self.use_robot_state:
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
if self.use_input_state:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
if self.use_env_state:
|
||||
self.encoder_env_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.use_robot_state:
|
||||
n_1d_tokens += 1
|
||||
if self.use_env_state:
|
||||
n_1d_tokens += 1
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.use_images:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
num_input_token_decoder = 2 if self.use_input_state else 1
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
@@ -379,13 +274,10 @@ class ACT(nn.Module):
|
||||
"""A forward pass through the Action Chunking Transformer (with optional VAE encoder).
|
||||
|
||||
`batch` should have the following structure:
|
||||
|
||||
{
|
||||
"observation.state" (optional): (B, state_dim) batch of robot states.
|
||||
|
||||
"observation.state": (B, state_dim) batch of robot states.
|
||||
"observation.images": (B, n_cameras, C, H, W) batch of images.
|
||||
AND/OR
|
||||
"observation.environment_state": (B, env_dim) batch of environment states.
|
||||
|
||||
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
|
||||
}
|
||||
|
||||
@@ -399,11 +291,7 @@ class ACT(nn.Module):
|
||||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
if "observation.images" in batch
|
||||
else batch["observation.environment_state"]
|
||||
).shape[0]
|
||||
batch_size = batch["observation.images"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
@@ -411,12 +299,12 @@ class ACT(nn.Module):
|
||||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
if self.use_robot_state:
|
||||
if self.use_input_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.use_robot_state:
|
||||
if self.use_input_state:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
@@ -426,19 +314,11 @@ class ACT(nn.Module):
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
|
||||
# Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the
|
||||
# sequence depending whether we use the input states or not (cls and robot state)
|
||||
# False means not a padding token.
|
||||
cls_joint_is_pad = torch.full(
|
||||
(batch_size, 2 if self.use_robot_state else 1),
|
||||
False,
|
||||
device=batch["observation.state"].device,
|
||||
)
|
||||
key_padding_mask = torch.cat(
|
||||
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
|
||||
) # (bs, seq+1 or 2)
|
||||
|
||||
# Forward pass through VAE encoder to get the latent PDF parameters.
|
||||
cls_joint_is_pad = torch.full((batch_size, 2), False).to(
|
||||
batch["observation.state"].device
|
||||
) # False: not a padding
|
||||
key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1) # (bs, seq+1)
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2),
|
||||
pos_embed=pos_embed.permute(1, 0, 2),
|
||||
@@ -459,54 +339,56 @@ class ACT(nn.Module):
|
||||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
# Prepare transformer encoder inputs.
|
||||
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
|
||||
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
|
||||
# Robot state token.
|
||||
if self.use_robot_state:
|
||||
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
|
||||
# Environment state token.
|
||||
if self.use_env_state:
|
||||
encoder_in_tokens.append(
|
||||
self.encoder_env_state_input_proj(batch["observation.environment_state"])
|
||||
)
|
||||
|
||||
# Prepare all other transformer encoder inputs.
|
||||
# Camera observation features and positional embeddings.
|
||||
if self.use_images:
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension.
|
||||
encoder_in = torch.cat(all_cam_features, axis=-1)
|
||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
||||
# Get positional embeddings for robot state and latent.
|
||||
if self.use_input_state:
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||
|
||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
|
||||
encoder_in = torch.cat(
|
||||
[
|
||||
torch.stack(encoder_in_feats, axis=0),
|
||||
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||
]
|
||||
)
|
||||
pos_embed = torch.cat(
|
||||
[
|
||||
self.encoder_robot_and_latent_pos_embed.weight.unsqueeze(1),
|
||||
cam_pos_embed.flatten(2).permute(2, 0, 1),
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
||||
decoder_in = torch.zeros(
|
||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||
dtype=encoder_in_pos_embed.dtype,
|
||||
device=encoder_in_pos_embed.device,
|
||||
dtype=pos_embed.dtype,
|
||||
device=pos_embed.device,
|
||||
)
|
||||
decoder_out = self.decoder(
|
||||
decoder_in,
|
||||
encoder_out,
|
||||
encoder_pos_embed=encoder_in_pos_embed,
|
||||
encoder_pos_embed=pos_embed,
|
||||
decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1),
|
||||
)
|
||||
|
||||
@@ -558,8 +440,9 @@ class ACTEncoderLayer(nn.Module):
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
q = k = x if pos_embed is None else x + pos_embed
|
||||
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
|
||||
x = x[0] # note: [0] to select just the output, not the attention weights
|
||||
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
|
||||
0
|
||||
] # select just the output, not the attention weights
|
||||
x = skip + self.dropout1(x)
|
||||
if self.pre_norm:
|
||||
skip = x
|
||||
@@ -579,7 +462,10 @@ class ACTDecoder(nn.Module):
|
||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
if config.decoder_norm:
|
||||
self.norm = nn.LayerNorm(config.dim_model)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -592,8 +478,7 @@ class ACTDecoder(nn.Module):
|
||||
x = layer(
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -28,12 +28,7 @@ class DiffusionConfig:
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- Either:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
AND/OR
|
||||
- The key "observation.environment_state" is required as input.
|
||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- A key starting with "observation.image is required as an input.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
@@ -158,33 +153,22 @@ class DiffusionConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
|
||||
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if len(image_keys) > 0:
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.crop_shape is not None and (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
supported_prediction_types = ["epsilon", "sample"]
|
||||
if self.prediction_type not in supported_prediction_types:
|
||||
raise ValueError(
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
TODO(alexander-soare):
|
||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||
- Make compatible with multiple image keys.
|
||||
"""
|
||||
|
||||
import math
|
||||
@@ -43,13 +44,7 @@ from lerobot.common.policies.utils import (
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
@@ -88,21 +83,23 @@ class DiffusionPolicy(
|
||||
|
||||
self.diffusion = DiffusionModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.use_env_state = "observation.environment_state" in config.input_shapes
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
if len(image_keys) != 1:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Clear observation and action queues. Should be called on `env.reset()`"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
if len(self.expected_image_keys) > 0:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -117,20 +114,18 @@ class DiffusionPolicy(
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h|
|
||||
|observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||
Note that this means we require: `n_action_steps < horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
@@ -149,9 +144,7 @@ class DiffusionPolicy(
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if len(self.expected_image_keys) > 0:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
@@ -175,20 +168,12 @@ class DiffusionModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Build observation encoders (depending on which observations are provided).
|
||||
global_cond_dim = config.input_shapes["observation.state"][0]
|
||||
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
self._use_images = False
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
config,
|
||||
global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
|
||||
* config.n_obs_steps,
|
||||
)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
@@ -235,44 +220,23 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||
if self._use_images:
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
global_cond_feats.append(batch["observation.environment_state"])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
|
||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This function expects `batch` to have:
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
# run sampling
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
@@ -289,28 +253,28 @@ class DiffusionModel(nn.Module):
|
||||
This function expects `batch` to have (at least):
|
||||
{
|
||||
"observation.state": (B, n_obs_steps, state_dim)
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
|
||||
"observation.image": (B, n_obs_steps, C, H, W)
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
|
||||
assert "observation.images" in batch or "observation.environment_state" in batch
|
||||
n_obs_steps = batch["observation.state"].shape[1]
|
||||
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
horizon = batch["action"].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
trajectory = batch["action"]
|
||||
|
||||
# Forward diffusion.
|
||||
trajectory = batch["action"]
|
||||
# Sample noise to add to the trajectory.
|
||||
eps = torch.randn(trajectory.shape, device=trajectory.device)
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
@@ -341,8 +305,7 @@ class DiffusionModel(nn.Module):
|
||||
if self.config.do_mask_loss_for_padding:
|
||||
if "action_is_pad" not in batch:
|
||||
raise ValueError(
|
||||
"You need to provide 'action_is_pad' in the batch when "
|
||||
f"{self.config.do_mask_loss_for_padding=}."
|
||||
f"You need to provide 'action_is_pad' in the batch when {self.config.do_mask_loss_for_padding=}."
|
||||
)
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
@@ -465,7 +428,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
|
||||
@@ -28,15 +28,9 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
logging.warning(
|
||||
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
|
||||
)
|
||||
|
||||
# OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid
|
||||
# issues with mutable defaults. This filter changes all lists to tuples.
|
||||
def list_to_tuple(item):
|
||||
return tuple(item) if isinstance(item, list) else item
|
||||
|
||||
policy_cfg = policy_cfg_class(
|
||||
**{
|
||||
k: list_to_tuple(v)
|
||||
k: v
|
||||
for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items()
|
||||
if k in expected_kwargs
|
||||
}
|
||||
@@ -61,11 +55,6 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy, ACTConfig
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -86,9 +75,7 @@ def make_policy(
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
"""
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
)
|
||||
raise ValueError("Only one of `pretrained_policy_name_or_path` and `dataset_stats` may be provided.")
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
@@ -99,10 +86,9 @@ def make_policy(
|
||||
else:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
|
||||
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
|
||||
# huggingface_hub should make it possible to avoid the hack:
|
||||
# https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, pretrained
|
||||
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
|
||||
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
|
||||
|
||||
|
||||
@@ -132,7 +132,6 @@ class Normalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
@@ -198,7 +197,6 @@ class Unnormalize(nn.Module):
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, mode in self.modes.items():
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class Policy(Protocol):
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Return one action to run in the environment (potentially in batch mode).
|
||||
|
||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||
|
||||
@@ -25,16 +25,12 @@ class TDMPCConfig:
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
@@ -104,7 +100,6 @@ class TDMPCConfig:
|
||||
# Input / output structure.
|
||||
n_action_repeats: int = 2
|
||||
horizon: int = 5
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -163,18 +158,17 @@ class TDMPCConfig:
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
if len(image_keys) != 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
@@ -185,12 +179,3 @@ class TDMPCConfig:
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
|
||||
@@ -19,10 +19,14 @@
|
||||
The comments in this code may sometimes refer to these references:
|
||||
TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955)
|
||||
FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029)
|
||||
|
||||
TODO(alexander-soare): Make rollout work for batch sizes larger than 1.
|
||||
TODO(alexander-soare): Use batch-first throughout.
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
@@ -41,13 +45,7 @@ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -58,11 +56,9 @@ class TDMPCPolicy(
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces the results from FOWM.
|
||||
- Nevertheless, we have verified that we can train TD-MPC for PushT. See
|
||||
`lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
match our xarm environment.
|
||||
"""
|
||||
|
||||
name = "tdmpc"
|
||||
@@ -78,6 +74,22 @@ class TDMPCPolicy(
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
logging.warning(
|
||||
"""
|
||||
Please note several warnings for this policy.
|
||||
|
||||
- Evaluation of pretrained weights created with the original FOWM code
|
||||
(https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a
|
||||
model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across
|
||||
to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter-
|
||||
process communication to use the xarm environment from FOWM. This is because our xarm
|
||||
environment uses newer dependencies and does not match the environment in FOWM. See
|
||||
https://github.com/huggingface/lerobot/pull/103 for implementation details.
|
||||
- We have NOT checked that training on LeRobot reproduces SOTA results. This is a TODO.
|
||||
- Our current xarm datasets were generated using the environment from FOWM. Therefore they do not
|
||||
match our xarm environment.
|
||||
"""
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = TDMPCConfig()
|
||||
@@ -102,14 +114,8 @@ class TDMPCPolicy(
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
assert len(image_keys) == 1
|
||||
self.input_image_key = image_keys[0]
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -119,24 +125,19 @@ class TDMPCPolicy(
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=1),
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
"action": deque(maxlen=self.config.n_action_repeats),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -150,57 +151,49 @@ class TDMPCPolicy(
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
if self.config.use_mpc: # noqa: SIM108
|
||||
actions = self.plan(z) # (horizon, batch, action_dim)
|
||||
z = self.model.encode({k: batch[k] for k in ["observation.image", "observation.state"]})
|
||||
if self.config.use_mpc:
|
||||
batch_size = batch["observation.image"].shape[0]
|
||||
# Batch processing is not handled in MPC mode, so process the batch in a loop.
|
||||
action = [] # will be a batch of actions for one step
|
||||
for i in range(batch_size):
|
||||
# Note: self.plan does not handle batches, hence the squeeze.
|
||||
action.append(self.plan(z[i]))
|
||||
action = torch.stack(action)
|
||||
else:
|
||||
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
|
||||
# sequence dimension like in the MPC branch.
|
||||
actions = self.model.pi(z).unsqueeze(0)
|
||||
# Plan with the policy (π) alone.
|
||||
action = self.model.pi(z)
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
self.unnormalize_outputs({"action": action})["action"]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.n_action_repeats > 1:
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(actions[0])
|
||||
else:
|
||||
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
|
||||
self._queues["action"].extend(actions[: self.config.n_action_steps])
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(action)
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
return torch.clamp(action, -1, 1)
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan sequence of actions using TD-MPC inference.
|
||||
"""Plan next action using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim,) tensor for the initial state.
|
||||
z: (latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
|
||||
(action_dim,) tensor for the next action.
|
||||
|
||||
TODO(alexander-soare) Extend this to be able to work with batches.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch_size = z.shape[0]
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
|
||||
_z = einops.repeat(z, "d -> n d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
@@ -209,14 +202,12 @@ class TDMPCPolicy(
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
z = einops.repeat(z, "d -> n d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
mean = torch.zeros(self.config.horizon, self.config.output_shapes["action"][0], device=device)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
@@ -227,7 +218,6 @@ class TDMPCPolicy(
|
||||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
@@ -236,24 +226,21 @@ class TDMPCPolicy(
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# Update guassian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0)[0]
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
_mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1)
|
||||
score /= score.sum()
|
||||
_mean = torch.sum(einops.rearrange(score, "n -> n 1") * elite_actions, dim=1)
|
||||
_std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2,
|
||||
einops.rearrange(score, "n -> n 1")
|
||||
* (elite_actions - einops.rearrange(_mean, "h d -> h 1 d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
@@ -268,9 +255,11 @@ class TDMPCPolicy(
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
actions = elite_actions[:, torch.multinomial(score, 1).item()]
|
||||
|
||||
return actions
|
||||
# Select only the first action
|
||||
action = actions[0]
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
@@ -322,17 +311,12 @@ class TDMPCPolicy(
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss."""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
@@ -342,12 +326,12 @@ class TDMPCPolicy(
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
reward = batch["next.reward"] # (t, b)
|
||||
action = batch["action"] # (t, b)
|
||||
reward = batch["next.reward"] # (t,)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
if self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
@@ -359,9 +343,7 @@ class TDMPCPolicy(
|
||||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
horizon = next_observations["observation.image"].shape[0]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
@@ -431,8 +413,7 @@ class TDMPCPolicy(
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
@@ -481,11 +462,10 @@ class TDMPCPolicy(
|
||||
action_preds = self.model.pi(z_preds[:-1]) # (t, b, a)
|
||||
# Calculate the MSE between the actions and the action predictions.
|
||||
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation
|
||||
# gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to
|
||||
# multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action
|
||||
# dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop
|
||||
# the 0.5 as we instead make a configuration parameter for it (see below where we compute the total
|
||||
# loss).
|
||||
# gaussian) and sums over the action dimension. Computing the log probability amounts to multiplying
|
||||
# the MSE by 0.5 and adding a constant offset (the log(2*pi) term) . Here we drop the constant offset
|
||||
# as it doesn't change the optimization step, and we drop the 0.5 as we instead make a configuration
|
||||
# parameter for it (see below where we compute the total loss).
|
||||
mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b)
|
||||
# NOTE: The original implementation does not take the sum over the temporal dimension like with the
|
||||
# other losses.
|
||||
@@ -746,16 +726,6 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
@@ -764,11 +734,8 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# NOTE: Order of observations matters here.
|
||||
if "observation.image" in self.config.input_shapes:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class VQBeTConfig:
|
||||
"""Configuration class for VQ-BeT.
|
||||
|
||||
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
n_vqvae_training_steps: Number of optimization steps for training Residual VQ.
|
||||
vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer).
|
||||
vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary.
|
||||
vqvae_enc_hidden_dim: Size of hidden dimensions of Encoder / Decoder part of Residaul VQ-VAE
|
||||
gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens)
|
||||
gpt_input_dim: Size of output input of GPT. This is also used as the dimension of observation features.
|
||||
gpt_output_dim: Size of output dimension of GPT. This is also used as a input dimension of offset / bin prediction headers.
|
||||
gpt_n_layer: Number of layers of GPT
|
||||
gpt_n_head: Number of headers of GPT
|
||||
gpt_hidden_dim: Size of hidden dimensions of GPT
|
||||
dropout: Dropout rate for GPT
|
||||
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
|
||||
offset_loss_weight: A constant that is multiplied to the offset loss
|
||||
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
|
||||
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
|
||||
bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
|
||||
sequentially_select: Whether select code of primary / secondary as sequentially (pick primary code,
|
||||
and then select secodnary code), or at the same time.
|
||||
"""
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 5
|
||||
n_action_pred_token: int = 3
|
||||
action_chunk_size: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: int = 20000
|
||||
vqvae_n_embed: int = 16
|
||||
vqvae_embedding_dim: int = 256
|
||||
vqvae_enc_hidden_dim: int = 128
|
||||
# VQ-BeT
|
||||
gpt_block_size: int = 500
|
||||
gpt_input_dim: int = 512
|
||||
gpt_output_dim: int = 512
|
||||
gpt_n_layer: int = 8
|
||||
gpt_n_head: int = 8
|
||||
gpt_hidden_dim: int = 512
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.0
|
||||
primary_code_loss_weight: float = 5.0
|
||||
secondary_code_loss_weight: float = 0.5
|
||||
bet_softmax_temperature: float = 0.1
|
||||
sequentially_select: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if self.crop_shape is not None:
|
||||
for image_key in image_keys:
|
||||
if (
|
||||
self.crop_shape[0] > self.input_shapes[image_key][1]
|
||||
or self.crop_shape[1] > self.input_shapes[image_key][2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {self.input_shapes[image_key]} for "
|
||||
"`input_shapes[{image_key}]`."
|
||||
)
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key = next(iter(image_keys))
|
||||
for image_key in image_keys:
|
||||
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
|
||||
raise ValueError(
|
||||
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
|
||||
"expect all image shapes to match."
|
||||
)
|
||||
@@ -1,959 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from collections import deque
|
||||
from typing import Callable, List
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
|
||||
name = "vqbet"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = VQBeTConfig()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.action_chunk_size),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
warnings.warn(
|
||||
"To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
|
||||
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
# since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
|
||||
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
|
||||
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||
-----------------------------------------------------
|
||||
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||
| ... | ... | ... | ... |
|
||||
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||
-----------------------------------------------------
|
||||
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||
|
||||
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) input feature maps.
|
||||
Returns:
|
||||
(B, K, 2) image-space coordinates of keypoints.
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
|
||||
class VQBeTModel(nn.Module):
|
||||
"""VQ-BeT: The underlying neural network for VQ-BeT
|
||||
|
||||
Note: In this code we use the terms `rgb_encoder`, 'policy', `action_head`. The meanings are as follows.
|
||||
- The `rgb_encoder` process rgb-style image observations to one-dimensional embedding vectors
|
||||
- A `policy` is a minGPT architecture, that takes observation sequences and action query tokens to generate `features`.
|
||||
- These `features` pass through the action head, which passes through the code prediction, offset prediction head,
|
||||
and finally generates a prediction for the action chunks.
|
||||
|
||||
-------------------------------** legend **-------------------------------
|
||||
│ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │
|
||||
│ o_{t} : visual observation at timestep {t} │
|
||||
│ s_{t} : state observation at timestep {t} │
|
||||
│ a_{t} : action at timestep {t} │
|
||||
│ A_Q : action_query_token │
|
||||
--------------------------------------------------------------------------
|
||||
|
||||
|
||||
Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps)
|
||||
|
||||
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ │ │ │ │ │
|
||||
│ RVQ encoder │ ─► │ Residual │ ─► │ RVQ Decoder │
|
||||
│ (a_{t}~a_{t+p}) │ │ Code Quantizer │ │ │
|
||||
│ │ │ │ │ │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
|
||||
Training Phase 2.
|
||||
|
||||
timestep {t-n+1} timestep {t-n+2} timestep {t}
|
||||
┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐
|
||||
|
||||
o_{t-n+1} o_{t-n+2} ... o_{t}
|
||||
│ │ │
|
||||
│ s_{t-n+1} │ s_{t-n+2} ... │ s_{t} p
|
||||
│ │ │ │ │ │ ┌───────┴───────┐
|
||||
│ │ A_Q │ │ A_Q ... │ │ A_Q ... A_Q
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐
|
||||
│ │
|
||||
│ GPT │ => policy
|
||||
│ │
|
||||
└───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘
|
||||
│ │ │ │
|
||||
┌───┴───┐ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐
|
||||
code offset code offset code offset code offset
|
||||
▼ │ ▼ │ ▼ │ ▼ │ => action_head
|
||||
RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │
|
||||
└── + ──┘ └── + ──┘ └── + ──┘ └── + ──┘
|
||||
▼ ▼ ▼ ▼
|
||||
action chunk action chunk action chunk action chunk
|
||||
a_{t-n+1} ~ a_{t-n+2} ~ a_{t} ~ ... a_{t+p-1} ~
|
||||
a_{t-n+c} a_{t-n+c+1} a_{t+c-1} a_{t+p+c-1}
|
||||
|
||||
▼
|
||||
ONLY this chunk is used in rollout!
|
||||
"""
|
||||
|
||||
def __init__(self, config: VQBeTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
|
||||
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
|
||||
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
|
||||
# GPT part of VQ-BeT
|
||||
self.policy = GPT(config)
|
||||
# bin prediction head / offset prediction head part of VQ-BeT
|
||||
self.action_head = VQBeTHead(config)
|
||||
|
||||
# Action tokens for: each observation step, the current action token, and all future action tokens.
|
||||
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
|
||||
self.register_buffer(
|
||||
"select_target_actions_indices",
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Extract image feature (first combine batch and sequence dims).
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
)
|
||||
|
||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||
# First project features to token dimension.
|
||||
rgb_tokens = self.rgb_feature_projector(
|
||||
img_features
|
||||
) # (batch, obs_step, number of different cameras, projection dims)
|
||||
input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
|
||||
input_tokens.append(
|
||||
self.state_projector(batch["observation.state"])
|
||||
) # (batch, obs_step, projection dims)
|
||||
input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack(input_tokens, dim=2)
|
||||
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
|
||||
|
||||
len_additional_action_token = self.config.n_action_pred_token - 1
|
||||
future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
# add additional action query tokens for predicting future action chunks
|
||||
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
|
||||
|
||||
# get action features (pass through GPT)
|
||||
features = self.policy(input_tokens)
|
||||
# len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||
self.config.input_shapes
|
||||
)
|
||||
|
||||
# only extract the output tokens at the position of action query:
|
||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||
# Thus, it predict historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
features = torch.cat(
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
)
|
||||
# pass through action head
|
||||
action_head_output = self.action_head(features)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
if rollout:
|
||||
return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
|
||||
batch_size, self.config.action_chunk_size, -1
|
||||
)
|
||||
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||
else:
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
|
||||
return action_head_output, loss
|
||||
|
||||
|
||||
class VQBeTHead(nn.Module):
|
||||
def __init__(self, config: VQBeTConfig):
|
||||
"""
|
||||
VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`)
|
||||
|
||||
self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
|
||||
The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
|
||||
and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`.
|
||||
if the agent select the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self._map_to_cbet_preds_bin.
|
||||
|
||||
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# init vqvae
|
||||
self.vqvae_model = VqVae(config)
|
||||
if config.sequentially_select:
|
||||
self.map_to_cbet_preds_primary_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.config.vqvae_n_embed],
|
||||
)
|
||||
self.map_to_cbet_preds_secondary_bin = MLP(
|
||||
in_channels=config.gpt_output_dim + self.config.vqvae_n_embed,
|
||||
hidden_channels=[self.config.vqvae_n_embed],
|
||||
)
|
||||
else:
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed],
|
||||
)
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.vqvae_model.vqvae_num_layers
|
||||
* self.config.vqvae_n_embed
|
||||
* config.action_chunk_size
|
||||
* config.output_shapes["action"][0],
|
||||
],
|
||||
)
|
||||
# loss
|
||||
self._focal_loss_fn = FocalLoss(gamma=2.0)
|
||||
|
||||
def discretize(self, n_vqvae_training_steps, actions):
|
||||
# Resize the action sequence data to fit the action chunk size using a sliding window approach.
|
||||
actions = torch.cat(
|
||||
[
|
||||
actions[:, j : j + self.config.action_chunk_size, :]
|
||||
for j in range(actions.shape[1] + 1 - self.config.action_chunk_size)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
# `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window.
|
||||
|
||||
loss, metric = self.vqvae_model.vqvae_forward(actions)
|
||||
n_different_codes = sum(
|
||||
[len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)]
|
||||
)
|
||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||
recon_l1_error = metric[0].detach().cpu().item()
|
||||
self.vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
|
||||
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
self.vqvae_model.discretized = torch.tensor(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
|
||||
print("Finished discretizing action data!")
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||
|
||||
# sample offsets
|
||||
cbet_offsets = self.map_to_cbet_preds_offset(x)
|
||||
cbet_offsets = einops.rearrange(
|
||||
cbet_offsets,
|
||||
"(NT) (G C WA) -> (NT) G C WA",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
C=self.config.vqvae_n_embed,
|
||||
)
|
||||
# if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code
|
||||
if self.config.sequentially_select:
|
||||
cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x)
|
||||
|
||||
# select primary bin first
|
||||
cbet_primary_probs = torch.softmax(
|
||||
cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
NT, choices = cbet_primary_probs.shape
|
||||
sampled_primary_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
|
||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||
torch.cat(
|
||||
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
cbet_secondary_probs = torch.softmax(
|
||||
cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1
|
||||
)
|
||||
sampled_secondary_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1),
|
||||
"(NT) 1 -> NT",
|
||||
NT=NT,
|
||||
)
|
||||
sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1)
|
||||
cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1)
|
||||
# if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once.
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
"(NT G) 1 -> NT G",
|
||||
NT=NT,
|
||||
)
|
||||
|
||||
device = get_device_from_parameters(self)
|
||||
indices = (
|
||||
torch.arange(NT, device=device).unsqueeze(1),
|
||||
torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0),
|
||||
sampled_centers,
|
||||
)
|
||||
# Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.)
|
||||
sampled_offsets = cbet_offsets[indices]
|
||||
# Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
with torch.no_grad():
|
||||
# Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder
|
||||
return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach()
|
||||
# pass the centroids through decoder to get actions.
|
||||
decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach()
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
)
|
||||
# add offset and decoded centroids
|
||||
predicted_action = decoded_action + sampled_offsets
|
||||
predicted_action = einops.rearrange(
|
||||
predicted_action,
|
||||
"(N T) W A -> N T (W A)",
|
||||
N=N,
|
||||
T=T,
|
||||
W=self.config.action_chunk_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"cbet_logits": cbet_logits,
|
||||
"predicted_action": predicted_action,
|
||||
"sampled_centers": sampled_centers,
|
||||
"decoded_action": decoded_action,
|
||||
}
|
||||
|
||||
def loss_fn(self, pred, target, **kwargs):
|
||||
"""
|
||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||
|
||||
predicted_action: predicted action chunk (offset + decoded centroids)
|
||||
sampled_centers: sampled centroids (code of RVQ)
|
||||
decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
|
||||
NT: batch size * T
|
||||
T: number of action query tokens, which are process through same GPT
|
||||
cbet_logits: probability of all codes in each layer
|
||||
"""
|
||||
action_seq = target
|
||||
predicted_action = pred["predicted_action"]
|
||||
sampled_centers = pred["sampled_centers"]
|
||||
decoded_action = pred["decoded_action"]
|
||||
NT = predicted_action.shape[0] * predicted_action.shape[1]
|
||||
|
||||
cbet_logits = pred["cbet_logits"]
|
||||
|
||||
predicted_action = einops.rearrange(
|
||||
predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
|
||||
)
|
||||
|
||||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
# offset loss is L1 distance between the predicted action and ground truth action
|
||||
offset_loss = F.l1_loss(action_seq, predicted_action)
|
||||
|
||||
# calculate primary code prediction loss
|
||||
cbet_loss1 = self._focal_loss_fn(
|
||||
cbet_logits[:, 0, :],
|
||||
action_bins[:, 0],
|
||||
)
|
||||
# calculate secondary code prediction loss
|
||||
cbet_loss2 = self._focal_loss_fn(
|
||||
cbet_logits[:, 1, :],
|
||||
action_bins[:, 1],
|
||||
)
|
||||
# add all the prediction loss
|
||||
cbet_loss = (
|
||||
cbet_loss1 * self.config.primary_code_loss_weight
|
||||
+ cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
)
|
||||
|
||||
equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
|
||||
equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)
|
||||
|
||||
action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
|
||||
vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
|
||||
offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
|
||||
action_error_max = torch.max(torch.abs(action_seq - predicted_action))
|
||||
|
||||
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss,
|
||||
"classification_loss": cbet_loss.detach().cpu().item(),
|
||||
"offset_loss": offset_loss.detach().cpu().item(),
|
||||
"equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
|
||||
"equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
|
||||
"vq_action_error": vq_action_error.detach().cpu().item(),
|
||||
"offset_action_error": offset_action_error.detach().cpu().item(),
|
||||
"action_error_max": action_error_max.detach().cpu().item(),
|
||||
"action_mse_error": action_mse_error.detach().cpu().item(),
|
||||
}
|
||||
return loss_dict
|
||||
|
||||
|
||||
class VQBeTOptimizer(torch.optim.Adam):
|
||||
def __init__(self, policy, cfg):
|
||||
vqvae_params = (
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.rgb_encoder.parameters())
|
||||
+ list(policy.vqbet.state_projector.parameters())
|
||||
+ list(policy.vqbet.rgb_feature_projector.parameters())
|
||||
+ [policy.vqbet.action_token]
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
|
||||
)
|
||||
|
||||
if cfg.policy.sequentially_select:
|
||||
decay_params = (
|
||||
decay_params
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
|
||||
+ list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
|
||||
)
|
||||
else:
|
||||
decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
|
||||
|
||||
optim_groups = [
|
||||
{
|
||||
"params": decay_params,
|
||||
"weight_decay": cfg.training.adam_weight_decay,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
{
|
||||
"params": vqvae_params,
|
||||
"weight_decay": 0.0001,
|
||||
"lr": cfg.training.vqvae_lr,
|
||||
},
|
||||
{
|
||||
"params": no_decay_params,
|
||||
"weight_decay": 0.0,
|
||||
"lr": cfg.training.lr,
|
||||
},
|
||||
]
|
||||
super().__init__(
|
||||
optim_groups,
|
||||
cfg.training.lr,
|
||||
cfg.training.adam_betas,
|
||||
cfg.training.adam_eps,
|
||||
)
|
||||
|
||||
|
||||
class VQBeTScheduler(nn.Module):
|
||||
def __init__(self, optimizer, cfg):
|
||||
super().__init__()
|
||||
n_vqvae_training_steps = cfg.training.n_vqvae_training_steps
|
||||
|
||||
num_warmup_steps = cfg.training.lr_warmup_steps
|
||||
num_training_steps = cfg.training.offline_steps
|
||||
num_cycles = 0.5
|
||||
|
||||
def lr_lambda(current_step):
|
||||
if current_step < n_vqvae_training_steps:
|
||||
return float(1)
|
||||
else:
|
||||
current_step = current_step - n_vqvae_training_steps
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
progress = float(current_step - num_warmup_steps) / float(
|
||||
max(1, num_training_steps - num_warmup_steps)
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
||||
|
||||
self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
def step(self):
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
class VQBeTRgbEncoder(nn.Module):
|
||||
"""Encode an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
|
||||
Same with DiffusionRgbEncoder from modeling_diffusion.py
|
||||
"""
|
||||
|
||||
def __init__(self, config: VQBeTConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
self.do_crop = False
|
||||
|
||||
# Set up backbone.
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
weights=config.pretrained_backbone_weights
|
||||
)
|
||||
# Note: This assumes that the layer4 feature map is children()[-3]
|
||||
# TODO(alexander-soare): Use a safer alternative.
|
||||
self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
|
||||
if config.use_group_norm:
|
||||
if config.pretrained_backbone_weights:
|
||||
raise ValueError(
|
||||
"You can't replace BatchNorm in a pretrained model without ruining the weights!"
|
||||
)
|
||||
self.backbone = _replace_submodules(
|
||||
root_module=self.backbone,
|
||||
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
|
||||
func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features),
|
||||
)
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.input_shapes` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, C, H, W) image tensor with pixel values in [0, 1].
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
else:
|
||||
# Always use center crop for eval.
|
||||
x = self.center_crop(x)
|
||||
# Extract backbone feature.
|
||||
x = torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||
# Final linear layer with non-linearity.
|
||||
x = self.relu(self.out(x))
|
||||
return x
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
root_module: The module for which the submodules need to be replaced
|
||||
predicate: Takes a module as an argument and must return True if the that module is to be replaced.
|
||||
func: Takes a module as an argument and returns a new module to replace it with.
|
||||
Returns:
|
||||
The root module with its submodules replaced.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parents, k in replace_list:
|
||||
parent_module = root_module
|
||||
if len(parents) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parents))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True))
|
||||
return root_module
|
||||
|
||||
|
||||
class VqVae(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig,
|
||||
):
|
||||
"""
|
||||
VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
|
||||
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||
The vq_layer uses residual VQs.
|
||||
|
||||
This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1),
|
||||
as well as functions to help BeT training part in training phase 2.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True)
|
||||
self.register_buffer("discretized", torch.tensor(False))
|
||||
self.optimized_steps = 0
|
||||
# we use the fixed number of layers for Residual VQ across all environments.
|
||||
self.vqvae_num_layers = 2
|
||||
|
||||
self.vq_layer = ResidualVQ(
|
||||
dim=config.vqvae_embedding_dim,
|
||||
num_quantizers=self.vqvae_num_layers,
|
||||
codebook_size=config.vqvae_n_embed,
|
||||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_embedding_dim,
|
||||
],
|
||||
)
|
||||
self.decoder = MLP(
|
||||
in_channels=config.vqvae_embedding_dim,
|
||||
hidden_channels=[
|
||||
config.vqvae_enc_hidden_dim,
|
||||
config.vqvae_enc_hidden_dim,
|
||||
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
],
|
||||
)
|
||||
|
||||
def get_embeddings_from_code(self, encoding_indices):
|
||||
# This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices.
|
||||
with torch.no_grad():
|
||||
z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices)
|
||||
# since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
|
||||
z_embed = z_embed.sum(dim=0)
|
||||
return z_embed
|
||||
|
||||
def get_action_from_latent(self, latent):
|
||||
# given latent vector, this function outputs the decoded action.
|
||||
output = self.decoder(latent)
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
||||
def get_code(self, state):
|
||||
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)
|
||||
# this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181)
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
with torch.no_grad():
|
||||
state_rep = self.encoder(state)
|
||||
state_rep_shape = state_rep.shape[:-1]
|
||||
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
|
||||
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
|
||||
state_vq = state_rep_flat.view(*state_rep_shape, -1)
|
||||
vq_code = vq_code.view(*state_rep_shape, -1)
|
||||
vq_loss_state = torch.sum(vq_loss_state)
|
||||
return state_vq, vq_code
|
||||
|
||||
def vqvae_forward(self, state):
|
||||
# This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181).
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
# We start with passing action (or action chunk) at:t+n through the encoder ϕ.
|
||||
state_rep = self.encoder(state)
|
||||
state_rep_shape = state_rep.shape[:-1]
|
||||
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
|
||||
# The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up.
|
||||
state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat)
|
||||
state_vq = state_rep_flat.view(*state_rep_shape, -1)
|
||||
vq_code = vq_code.view(*state_rep_shape, -1)
|
||||
# since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination.
|
||||
vq_loss_state = torch.sum(vq_loss_state)
|
||||
# Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ.
|
||||
dec_out = self.decoder(state_vq)
|
||||
# Calculate L1 reconstruction loss
|
||||
encoder_loss = (state - dec_out).abs().mean()
|
||||
# add encoder reconstruction loss and commitment loss
|
||||
rep_loss = encoder_loss + vq_loss_state * 5
|
||||
|
||||
metric = (
|
||||
encoder_loss.clone().detach(),
|
||||
vq_loss_state.clone().detach(),
|
||||
vq_code,
|
||||
rep_loss.item(),
|
||||
)
|
||||
return rep_loss, metric
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""
|
||||
From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py
|
||||
"""
|
||||
|
||||
def __init__(self, gamma: float = 0, size_average: bool = True):
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.size_average = size_average
|
||||
|
||||
def forward(self, input, target):
|
||||
if len(input.shape) == 3:
|
||||
N, T, _ = input.shape
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||
elif len(input.shape) == 2:
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||
pt = logpt.exp()
|
||||
|
||||
loss = -1 * (1 - pt) ** self.gamma * logpt
|
||||
if self.size_average:
|
||||
return loss.mean()
|
||||
else:
|
||||
return loss.sum()
|
||||
|
||||
|
||||
class MLP(torch.nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: List[int],
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim))
|
||||
layers.append(torch.nn.ReLU())
|
||||
in_dim = hidden_dim
|
||||
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1]))
|
||||
|
||||
super().__init__(*layers)
|
||||
File diff suppressed because it is too large
Load Diff
325
lerobot/common/robot_devices/cameras/intelrealsense.py
Normal file
325
lerobot/common/robot_devices/cameras/intelrealsense.py
Normal file
@@ -0,0 +1,325 @@
|
||||
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
import time
|
||||
import traceback
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pyrealsense2 as rs
|
||||
|
||||
from lerobot.common.robot_devices.cameras.opencv import find_camera_indices
|
||||
from lerobot.common.robot_devices.cameras.utils import save_color_image, save_depth_image
|
||||
|
||||
SERIAL_NUMBER_INDEX = 1
|
||||
|
||||
def find_camera_indices(raise_when_empty=True):
|
||||
camera_ids = []
|
||||
for device in rs.context().query_devices():
|
||||
serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX)))
|
||||
camera_ids.append(serial_number)
|
||||
|
||||
if raise_when_empty and len(camera_ids) == 0:
|
||||
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware.")
|
||||
|
||||
return camera_ids
|
||||
|
||||
def benchmark_cameras(cameras, out_dir=None, save_images=False):
|
||||
if save_images:
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
while True:
|
||||
now = time.time()
|
||||
for camera in cameras:
|
||||
if camera.use_depth:
|
||||
color_image, depth_image = camera.capture_image("bgr" if save_images else "rgb")
|
||||
else:
|
||||
color_image = camera.capture_image("bgr" if save_images else "rgb")
|
||||
|
||||
if save_images:
|
||||
image_path = out_dir / f"camera_{camera.camera_index:02}.png"
|
||||
print(f"Write to {image_path}")
|
||||
save_color_image(color_image, image_path, write_shape=True)
|
||||
|
||||
if camera.use_depth:
|
||||
# Apply colormap on depth image (image must be converted to 8-bit per pixel first)
|
||||
depth_image_path = out_dir / f"camera_{camera.camera_index:02}_depth.png"
|
||||
print(f"Write to {depth_image_path}")
|
||||
save_depth_image(depth_image_path, depth_image, write_shape=True)
|
||||
|
||||
dt_s = (time.time() - now)
|
||||
dt_ms = dt_s * 1000
|
||||
freq = 1 / dt_s
|
||||
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
|
||||
|
||||
if save_images:
|
||||
break
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
# Pre-defined configs that worked
|
||||
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(30, 640, 480)
|
||||
IntelRealSenseCameraConfig(60, 640, 480)
|
||||
IntelRealSenseCameraConfig(90, 640, 480)
|
||||
IntelRealSenseCameraConfig(30, 1280, 720)
|
||||
|
||||
IntelRealSenseCameraConfig(30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(60, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(90, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(30, 1280, 720, use_depth=True)
|
||||
```
|
||||
"""
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color: str = "rgb"
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
|
||||
|
||||
|
||||
class IntelRealSenseCamera():
|
||||
# TODO(rcadene): improve dosctring
|
||||
"""
|
||||
Using this class requires:
|
||||
- [installing `librealsense` and its python wrapper `pyrealsense2`](https://github.com/IntelRealSense/librealsense/blob/master/doc/distribution_linux.md)
|
||||
- [updating the camera(s) firmware](https://dev.intelrealsense.com/docs/firmware-releases-d400)
|
||||
|
||||
Example of getting the `camera_index` for your camera(s):
|
||||
```bash
|
||||
rs-fw-update -l
|
||||
|
||||
> Connected devices:
|
||||
> 1) [USB] Intel RealSense D405 s/n 128422270109, update serial number: 133323070634, firmware version: 5.16.0.1
|
||||
> 2) [USB] Intel RealSense D405 s/n 128422271609, update serial number: 130523070758, firmware version: 5.16.0.1
|
||||
> 3) [USB] Intel RealSense D405 s/n 128422271614, update serial number: 133323070576, firmware version: 5.16.0.1
|
||||
> 4) [USB] Intel RealSense D405 s/n 128422271393, update serial number: 133323070271, firmware version: 5.16.0.1
|
||||
```
|
||||
|
||||
Example of uage:
|
||||
|
||||
```python
|
||||
camera = IntelRealSenseCamera(128422270109) # serial number (s/n)
|
||||
color_image = camera.capture_image()
|
||||
```
|
||||
|
||||
Example of capturing additional depth image:
|
||||
|
||||
```python
|
||||
config = IntelRealSenseCameraConfig(use_depth=True)
|
||||
camera = IntelRealSenseCamera(128422270109, config)
|
||||
color_image, depth_image = camera.capture_image()
|
||||
```
|
||||
"""
|
||||
AVAILABLE_CAMERA_INDICES = find_camera_indices()
|
||||
|
||||
def __init__(self,
|
||||
camera_index: int | None = None,
|
||||
config: IntelRealSenseCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
config = IntelRealSenseCameraConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.camera_index = camera_index
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.color = config.color
|
||||
self.use_depth = config.use_depth
|
||||
self.force_hardware_reset = config.force_hardware_reset
|
||||
|
||||
# TODO(rcadene): move these two check in config dataclass
|
||||
if self.color not in ["rgb", "bgr"]:
|
||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
||||
|
||||
if (self.fps or self.width or self.height) and not (self.fps and self.width and self.height):
|
||||
raise ValueError(f"Expected all fps, width and height to be set, when one of them is set, but {self.fps=}, {self.width=}, {self.height=}.")
|
||||
|
||||
if self.camera_index is None:
|
||||
raise ValueError(f"`camera_index` is expected to be a serial number of one of these available cameras ({IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES}), but {camera_index} is provided instead.")
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
|
||||
self.t = Thread(target=self.capture_image_loop, args=())
|
||||
self.t.daemon = True
|
||||
self._color_image = None
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise ValueError(f"Camera {self.camera_index} is already connected.")
|
||||
|
||||
config = rs.config()
|
||||
config.enable_device(str(self.camera_index))
|
||||
|
||||
if self.fps and self.width and self.height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.width and self.height:
|
||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
self.camera = rs.pipeline()
|
||||
try:
|
||||
self.camera.start(config)
|
||||
except RuntimeError:
|
||||
# Verify that the provided `camera_index` is valid before printing the traceback
|
||||
if self.camera_index not in IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES:
|
||||
raise ValueError(f"`camera_index` is expected to be a serial number of one of these available cameras {IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES}, but {self.camera_index} is provided instead.")
|
||||
traceback.print_exc()
|
||||
|
||||
self.is_connected = True
|
||||
self.t.start()
|
||||
|
||||
def capture_image(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
|
||||
frame = self.camera.wait_for_frames()
|
||||
|
||||
color_frame = frame.get_color_frame()
|
||||
if not color_frame:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
if temporary_color is None:
|
||||
requested_color = self.color
|
||||
else:
|
||||
requested_color = temporary_color
|
||||
|
||||
if requested_color not in ["rgb", "bgr"]:
|
||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
|
||||
|
||||
# OpenCV uses BGR format as default (blue, green red) for all operations, including displaying images.
|
||||
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
|
||||
# so we convert the image color from BGR to RGB.
|
||||
# if requested_color == "rgb":
|
||||
# color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(f"Can't capture depth image from camera {self.camera_index}.")
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
return color_image, depth_image
|
||||
else:
|
||||
return color_image
|
||||
|
||||
def capture_image_loop(self):
|
||||
while True:
|
||||
self._color_image = self.capture_image()
|
||||
|
||||
def read(self):
|
||||
while self._color_image is None:
|
||||
time.sleep(0.1)
|
||||
return self._color_image
|
||||
|
||||
def disconnect(self):
|
||||
if getattr(self, "camera", None):
|
||||
try:
|
||||
self.camera.stop()
|
||||
except RuntimeError as e:
|
||||
if "stop() cannot be called before start()" in str(e):
|
||||
# skip this runtime error
|
||||
return
|
||||
traceback.print_exc()
|
||||
|
||||
def __del__(self):
|
||||
self.disconnect()
|
||||
|
||||
|
||||
def save_images_config(config, out_dir: Path):
|
||||
camera_ids = IntelRealSenseCamera.AVAILABLE_CAMERA_INDICES
|
||||
cameras = []
|
||||
print(f"Available camera indices: {camera_ids}")
|
||||
for camera_idx in camera_ids:
|
||||
camera = IntelRealSenseCamera(camera_idx, config)
|
||||
cameras.append(camera)
|
||||
|
||||
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}_depth_{config.use_depth}"
|
||||
benchmark_cameras(cameras, out_dir, save_images=True)
|
||||
|
||||
def benchmark_config(config, camera_ids: list[int]):
|
||||
cameras = [IntelRealSenseCamera(idx, config) for idx in camera_ids]
|
||||
benchmark_cameras(cameras)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
|
||||
parser.add_argument("--camera-ids", type=int, nargs="*", default=[128422271609, 128422271614, 128422271393])
|
||||
parser.add_argument("--fps", type=int, default=30)
|
||||
parser.add_argument("--width", type=str, default=640)
|
||||
parser.add_argument("--height", type=str, default=480)
|
||||
parser.add_argument("--use-depth", type=int, default=0)
|
||||
parser.add_argument("--out-dir", type=Path, default="outputs/benchmark_cameras/intelrealsense/2024_06_22_1738")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = IntelRealSenseCameraConfig(args.fps, args.width, args.height, use_depth=bool(args.use_depth))
|
||||
# config = IntelRealSenseCameraConfig()
|
||||
# config = IntelRealSenseCameraConfig(60, 640, 480)
|
||||
# config = IntelRealSenseCameraConfig(90, 640, 480)
|
||||
# config = IntelRealSenseCameraConfig(30, 1280, 720)
|
||||
|
||||
if args.mode == "save_images":
|
||||
save_images_config(config, args.out_dir)
|
||||
elif args.mode == "benchmark":
|
||||
benchmark_config(config, args.camera_ids)
|
||||
else:
|
||||
raise ValueError(args.mode)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Works well!
|
||||
# # use_depth = False
|
||||
# # fps = 90
|
||||
# # width = 640
|
||||
# # height = 480
|
||||
|
||||
# # # Works well!
|
||||
# # use_depth = True
|
||||
# # fps = 90
|
||||
# # width = 640
|
||||
# # height = 480
|
||||
|
||||
# # # Doesn't work well, latency varies too much
|
||||
# # use_depth = True
|
||||
# # fps = 30
|
||||
# # width = 1280
|
||||
# # height = 720
|
||||
|
||||
# # Works well
|
||||
# use_depth = False
|
||||
# fps = 30
|
||||
# width = 1280
|
||||
# height = 720
|
||||
|
||||
# config = IntelRealSenseCameraConfig()
|
||||
# # config = IntelRealSenseCameraConfig(fps, width, height, use_depth=use_depth)
|
||||
# cameras = [
|
||||
# # IntelRealSenseCamera(0, config),
|
||||
# # IntelRealSenseCamera(128422270109, config),
|
||||
# IntelRealSenseCamera(128422271609, config),
|
||||
# IntelRealSenseCamera(128422271614, config),
|
||||
# IntelRealSenseCamera(128422271393, config),
|
||||
# ]
|
||||
|
||||
# out_dir = "outputs/benchmark_cameras/intelrealsense/2024_06_22_1729"
|
||||
# out_dir += f"{config.width}x{config.height}_{config.fps}_depth_{config.use_depth}"
|
||||
# benchmark_cameras(cameras, out_dir, save_images=False)
|
||||
@@ -1,55 +1,19 @@
|
||||
"""
|
||||
This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import math
|
||||
import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, replace
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
from lerobot.scripts.control_robot import busy_wait
|
||||
|
||||
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
# The maximum opencv device index depends on your operating system. For instance,
|
||||
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
|
||||
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
|
||||
# When you change the USB port or reboot the computer, the operating system might
|
||||
# treat the same cameras as new devices. Thus we select a higher bound to search indices.
|
||||
MAX_OPENCV_INDEX = 60
|
||||
from lerobot.common.robot_devices.cameras.utils import save_color_image
|
||||
|
||||
|
||||
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
|
||||
if platform.system() == "Linux":
|
||||
# Linux uses camera ports
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
possible_camera_ids = []
|
||||
for port in Path("/dev").glob("video*"):
|
||||
camera_idx = int(str(port).replace("/dev/video", ""))
|
||||
possible_camera_ids.append(camera_idx)
|
||||
else:
|
||||
print(
|
||||
"Mac or Windows detected. Finding available camera indices through "
|
||||
f"scanning all indices from 0 to {MAX_OPENCV_INDEX}"
|
||||
)
|
||||
possible_camera_ids = range(max_index_search_range)
|
||||
|
||||
def find_camera_indices(raise_when_empty=False, max_index_search_range=60):
|
||||
camera_ids = []
|
||||
for camera_idx in possible_camera_ids:
|
||||
for camera_idx in range(max_index_search_range):
|
||||
camera = cv2.VideoCapture(camera_idx)
|
||||
is_open = camera.isOpened()
|
||||
camera.release()
|
||||
@@ -59,84 +23,49 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENC
|
||||
camera_ids.append(camera_idx)
|
||||
|
||||
if raise_when_empty and len(camera_ids) == 0:
|
||||
raise OSError(
|
||||
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
|
||||
"or your camera driver, or make sure your camera is compatible with opencv2."
|
||||
)
|
||||
raise OSError("Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2.")
|
||||
|
||||
return camera_ids
|
||||
|
||||
def benchmark_cameras(cameras, out_dir=None, save_images=False, num_warmup_frames=4):
|
||||
if out_dir:
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save_image(img_array, camera_index, frame_index, images_dir):
|
||||
img = Image.fromarray(img_array)
|
||||
path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
for _ in range(num_warmup_frames):
|
||||
for camera in cameras:
|
||||
try:
|
||||
camera.capture_image()
|
||||
time.sleep(0.01)
|
||||
except OSError as e:
|
||||
print(e)
|
||||
|
||||
while True:
|
||||
now = time.time()
|
||||
for camera in cameras:
|
||||
color_image = camera.capture_image("bgr" if save_images else "rgb")
|
||||
|
||||
def save_images_from_cameras(
|
||||
images_dir: Path, camera_ids: list[int] | None = None, fps=None, width=None, height=None, record_time_s=2
|
||||
):
|
||||
if camera_ids is None:
|
||||
camera_ids = find_camera_indices()
|
||||
if save_images:
|
||||
image_path = out_dir / f"camera_{camera.camera_index:02}.png"
|
||||
print(f"Write to {image_path}")
|
||||
save_color_image(color_image, image_path, write_shape=True)
|
||||
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
dt_s = (time.time() - now)
|
||||
dt_ms = dt_s * 1000
|
||||
freq = 1 / dt_s
|
||||
print(f"Latency (ms): {dt_ms:.2f}\tFrequency: {freq:.2f}")
|
||||
|
||||
images_dir = Path(images_dir)
|
||||
if images_dir.exists():
|
||||
shutil.rmtree(
|
||||
images_dir,
|
||||
)
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Saving images to {images_dir}")
|
||||
frame_index = 0
|
||||
start_time = time.perf_counter()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
for camera in cameras:
|
||||
# If we use async_read when fps is None, the loop will go full speed, and we will endup
|
||||
# saving the same images from the cameras multiple times until the RAM/disk is full.
|
||||
image = camera.read() if fps is None else camera.async_read()
|
||||
|
||||
executor.submit(
|
||||
save_image,
|
||||
image,
|
||||
camera.camera_index,
|
||||
frame_index,
|
||||
images_dir,
|
||||
)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
if time.perf_counter() - start_time > record_time_s:
|
||||
break
|
||||
|
||||
print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}")
|
||||
|
||||
frame_index += 1
|
||||
|
||||
print(f"Images have been saved to {images_dir}")
|
||||
if save_images:
|
||||
break
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenCVCameraConfig:
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(30, 640, 480)
|
||||
OpenCVCameraConfig(60, 640, 480)
|
||||
@@ -144,58 +73,26 @@ class OpenCVCameraConfig:
|
||||
OpenCVCameraConfig(30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color_mode values are 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
color: str = "rgb"
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
|
||||
class OpenCVCamera():
|
||||
# TODO(rcadene): improve dosctring
|
||||
"""
|
||||
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
|
||||
with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
https://docs.opencv.org/4.x/d0/da7/videoio_overview.html
|
||||
https://docs.opencv.org/4.x/d4/d15/group__videoio__flags__base.html#ga023786be1ee68a9105bf2e48c700294d
|
||||
|
||||
An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera
|
||||
like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index
|
||||
might change if you reboot your computer or re-plug your camera. This behavior depends on your operation system.
|
||||
Example of uage:
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
camera.disconnect()
|
||||
```
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
camera = OpenCVCamera(0, fps=30, width=1280, height=720)
|
||||
camera = connect() # applies the settings, might error out if these settings are not compatible with the camera
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480)
|
||||
camera = connect()
|
||||
|
||||
camera = OpenCVCamera(0, fps=90, width=640, height=480, color_mode="bgr")
|
||||
camera = connect()
|
||||
camera = OpenCVCamera(2)
|
||||
color_image = camera.capture_image()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
@@ -206,28 +103,28 @@ class OpenCVCamera:
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.color_mode = config.color_mode
|
||||
self.color = config.color
|
||||
|
||||
if self.color not in ["rgb", "bgr"]:
|
||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {self.color} is provided.")
|
||||
|
||||
if self.camera_index is None:
|
||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {camera_index} is provided instead.")
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
self.color_image = None
|
||||
self.logs = {}
|
||||
|
||||
self.t = Thread(target=self.capture_image_loop, args=())
|
||||
self.t.daemon = True
|
||||
self._color_image = None
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(f"Camera {self.camera_index} is already connected.")
|
||||
raise ValueError(f"Camera {self.camera_index} is already connected.")
|
||||
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
|
||||
if platform.system() == "Linux":
|
||||
# Linux uses ports for connecting to cameras
|
||||
tmp_camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
|
||||
else:
|
||||
tmp_camera = cv2.VideoCapture(self.camera_index)
|
||||
|
||||
tmp_camera = cv2.VideoCapture(self.camera_index)
|
||||
is_camera_open = tmp_camera.isOpened()
|
||||
# Release camera to make it accessible for `find_camera_indices`
|
||||
del tmp_camera
|
||||
@@ -236,188 +133,117 @@ class OpenCVCamera:
|
||||
# valid cameras.
|
||||
if not is_camera_open:
|
||||
# Verify that the provided `camera_index` is valid before printing the traceback
|
||||
available_cam_ids = find_camera_indices()
|
||||
if self.camera_index not in available_cam_ids:
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
|
||||
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
)
|
||||
if self.camera_index not in find_camera_indices():
|
||||
raise ValueError(f"`camera_index` is expected to be one of these available cameras {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}, but {self.camera_index} is provided instead.")
|
||||
|
||||
raise OSError(f"Can't access camera {self.camera_index}.")
|
||||
|
||||
|
||||
# Secondly, create the camera that will be used downstream.
|
||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||
# needs to be re-created.
|
||||
if platform.system() == "Linux":
|
||||
self.camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
|
||||
else:
|
||||
self.camera = cv2.VideoCapture(self.camera_index)
|
||||
self.camera = cv2.VideoCapture(self.camera_index)
|
||||
|
||||
if self.fps is not None:
|
||||
if self.fps:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
if self.width is not None:
|
||||
if self.width:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
||||
if self.height is not None:
|
||||
if self.height:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
||||
|
||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||
actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)
|
||||
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}."
|
||||
)
|
||||
if self.width is not None and self.width != actual_width:
|
||||
raise OSError(
|
||||
f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}."
|
||||
)
|
||||
if self.height is not None and self.height != actual_height:
|
||||
raise OSError(
|
||||
f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = actual_fps
|
||||
self.width = actual_width
|
||||
self.height = actual_height
|
||||
|
||||
if self.fps and self.fps != actual_fps:
|
||||
raise OSError(f"Can't set {self.fps=} for camera {self.camera_index}. Actual value is {actual_fps}.")
|
||||
if self.width and self.width != actual_width:
|
||||
raise OSError(f"Can't set {self.width=} for camera {self.camera_index}. Actual value is {actual_width}.")
|
||||
if self.height and self.height != actual_height:
|
||||
raise OSError(f"Can't set {self.height=} for camera {self.camera_index}. Actual value is {actual_height}.")
|
||||
|
||||
self.is_connected = True
|
||||
self.t.start()
|
||||
|
||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||
"""Read a frame from the camera returned in the format (height, width, channels)
|
||||
(e.g. (640, 480, 3)), contrarily to the pytorch format which is channel first.
|
||||
|
||||
Note: Reading a frame is done every `camera.fps` times per second, and it is blocking.
|
||||
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
|
||||
"""
|
||||
def capture_image(self, temporary_color: str | None = None) -> np.ndarray:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
self.connect()
|
||||
|
||||
ret, color_image = self.camera.read()
|
||||
if not ret:
|
||||
raise OSError(f"Can't capture color image from camera {self.camera_index}.")
|
||||
|
||||
requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode
|
||||
|
||||
if requested_color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided."
|
||||
)
|
||||
|
||||
if temporary_color is None:
|
||||
requested_color = self.color
|
||||
else:
|
||||
requested_color = temporary_color
|
||||
|
||||
if requested_color not in ["rgb", "bgr"]:
|
||||
raise ValueError(f"Expected color values are 'rgb' or 'bgr', but {requested_color} is provided.")
|
||||
|
||||
# OpenCV uses BGR format as default (blue, green red) for all operations, including displaying images.
|
||||
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if requested_color == "rgb":
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.height or w != self.width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
|
||||
# log the number of seconds it took to read the image
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the image was received
|
||||
self.logs["timestamp_utc"] = capture_timestamp_utc()
|
||||
|
||||
|
||||
return color_image
|
||||
|
||||
def capture_image_loop(self):
|
||||
while True:
|
||||
self._color_image = self.capture_image()
|
||||
|
||||
def read_loop(self):
|
||||
while self.stop_event is None or not self.stop_event.is_set():
|
||||
self.color_image = self.read()
|
||||
|
||||
def async_read(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is None:
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = Thread(target=self.read_loop, args=())
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
num_tries = 0
|
||||
while self.color_image is None:
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
raise Exception(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
|
||||
return self.color_image
|
||||
def read(self):
|
||||
while self._color_image is None:
|
||||
time.sleep(0.1)
|
||||
return self._color_image
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
# wait for the thread to finish
|
||||
self.stop_event.set()
|
||||
self.thread.join()
|
||||
self.thread = None
|
||||
self.stop_event = None
|
||||
|
||||
self.camera.release()
|
||||
self.camera = None
|
||||
|
||||
self.is_connected = False
|
||||
if getattr(self, "camera", None):
|
||||
self.camera.release()
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
self.disconnect()
|
||||
|
||||
|
||||
def save_images_config(config: OpenCVCameraConfig, out_dir: Path):
|
||||
cameras = []
|
||||
print(f"Available camera indices: {OpenCVCamera.AVAILABLE_CAMERAS_INDICES}")
|
||||
for camera_idx in OpenCVCamera.AVAILABLE_CAMERAS_INDICES:
|
||||
camera = OpenCVCamera(camera_idx, config)
|
||||
cameras.append(camera)
|
||||
|
||||
out_dir = out_dir.parent / f"{out_dir.name}_{config.width}x{config.height}_{config.fps}"
|
||||
benchmark_cameras(cameras, out_dir, save_images=True)
|
||||
|
||||
def benchmark_config(config: OpenCVCameraConfig, camera_ids: list[int]):
|
||||
cameras = [OpenCVCamera(idx, config) for idx in camera_ids]
|
||||
benchmark_cameras(cameras)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--camera-ids",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--images-dir",
|
||||
type=Path,
|
||||
default="outputs/images_from_opencv_cameras",
|
||||
help="Set directory to save a few frames for each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--record-time-s",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Set the number of seconds used to record the frames. By default, 2 seconds.",
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--mode", type=str, choices=["save_images", 'benchmark'], default="save_images")
|
||||
parser.add_argument("--camera-ids", type=int, nargs="*", default=[16, 4, 22, 10])
|
||||
parser.add_argument("--fps", type=int, default=30)
|
||||
parser.add_argument("--width", type=str, default=640)
|
||||
parser.add_argument("--height", type=str, default=480)
|
||||
parser.add_argument("--out-dir", type=Path, default="outputs/benchmark_cameras/opencv/2024_06_22_1727")
|
||||
args = parser.parse_args()
|
||||
save_images_from_cameras(**vars(args))
|
||||
|
||||
config = OpenCVCameraConfig(args.fps, args.width, args.height)
|
||||
# config = OpenCVCameraConfig()
|
||||
# config = OpenCVCameraConfig(60, 640, 480)
|
||||
# config = OpenCVCameraConfig(90, 640, 480)
|
||||
# config = OpenCVCameraConfig(30, 1280, 720)
|
||||
|
||||
if args.mode == "save_images":
|
||||
save_images_config(config, args.out_dir)
|
||||
elif args.mode == "benchmark":
|
||||
benchmark_config(config, args.camera_ids)
|
||||
else:
|
||||
raise ValueError(args.mode)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
|
||||
from pathlib import Path
|
||||
import time
|
||||
import cv2
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
|
||||
def write_shape_on_image_inplace(image):
|
||||
height, width = image.shape[:2]
|
||||
text = f"Width: {width} Height: {height}"
|
||||
text = f'Width: {width} Height: {height}'
|
||||
|
||||
# Define the font, scale, color, and thickness
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
@@ -40,19 +41,8 @@ def save_depth_image(depth, path, write_shape=False):
|
||||
cv2.imwrite(str(path), depth_image)
|
||||
|
||||
|
||||
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
|
||||
assert tensor.ndim == 3
|
||||
c, h, w = tensor.shape
|
||||
assert c < h and c < w
|
||||
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
|
||||
if rgb_to_bgr:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
return color_image
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
|
||||
def async_read(self) -> np.ndarray: ...
|
||||
def disconnect(self): ...
|
||||
|
||||
@@ -1,32 +1,15 @@
|
||||
import enum
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import enum
|
||||
from typing import Union
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from dynamixel_sdk import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
DXL_HIWORD,
|
||||
DXL_LOBYTE,
|
||||
DXL_LOWORD,
|
||||
GroupSyncRead,
|
||||
GroupSyncWrite,
|
||||
PacketHandler,
|
||||
PortHandler,
|
||||
)
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
from dynamixel_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
|
||||
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUDRATE = 1_000_000
|
||||
BAUD_RATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
|
||||
@@ -86,21 +69,13 @@ X_SERIES_CONTROL_TABLE = {
|
||||
"Velocity_Trajectory": (136, 4),
|
||||
"Position_Trajectory": (140, 4),
|
||||
"Present_Input_Voltage": (144, 2),
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
X_SERIES_BAUDRATE_TABLE = {
|
||||
0: 9_600,
|
||||
1: 57_600,
|
||||
2: 115_200,
|
||||
3: 1_000_000,
|
||||
4: 2_000_000,
|
||||
5: 3_000_000,
|
||||
6: 4_000_000,
|
||||
"Present_Temperature": (146, 1)
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
#CONVERT_POSITION_TO_ANGLE_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_POSITION_TO_ANGLE_REQUIRED = []
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"x_series": X_SERIES_CONTROL_TABLE,
|
||||
@@ -111,140 +86,58 @@ MODEL_CONTROL_TABLE = {
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"x_series": 4096,
|
||||
"xl330-m077": 4096,
|
||||
"xl330-m288": 4096,
|
||||
"xl430-w250": 4096,
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"x_series": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]):
|
||||
"""This function convert the degree range to the step range for indicating motors rotation.
|
||||
It assums a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
def uint32_to_int32(values: np.ndarray):
|
||||
"""
|
||||
if isinstance(degrees, float):
|
||||
degrees = np.array(degrees)
|
||||
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
||||
"""
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None and values[i] > 2147483647:
|
||||
values[i] = values[i] - 4294967296
|
||||
return values
|
||||
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
def int32_to_uint32(values: np.ndarray):
|
||||
"""
|
||||
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
|
||||
"""
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None and values[i] < 0:
|
||||
values[i] = values[i] + 4294967296
|
||||
return values
|
||||
|
||||
def motor_position_to_angle(position: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert from motor position in [-2048, 2048] to radian in [-pi, pi]
|
||||
"""
|
||||
return (position / 2048) * 3.14
|
||||
|
||||
def motor_angle_to_position(angle: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert from radian in [-pi, pi] to motor position in [-2048, 2048]
|
||||
"""
|
||||
return ((angle / 3.14) * 2048).astype(np.int64)
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
return data
|
||||
# def pwm2vel(pwm: np.ndarray) -> np.ndarray:
|
||||
# """
|
||||
# :param pwm: numpy array of pwm/s joint velocities
|
||||
# :return: numpy array of rad/s joint velocities
|
||||
# """
|
||||
# return pwm * 3.14 / 2048
|
||||
|
||||
|
||||
# def vel2pwm(vel: np.ndarray) -> np.ndarray:
|
||||
# """
|
||||
# :param vel: numpy array of rad/s joint velocities
|
||||
# :return: numpy array of pwm/s joint velocities
|
||||
# """
|
||||
# return (vel * 2048 / 3.14).astype(np.int64)
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
group_key = f"{data_name}_" + "_".join([name for name in motor_names])
|
||||
return group_key
|
||||
|
||||
|
||||
def get_result_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
rslt_name = f"{fn_name}_{group_key}"
|
||||
return rslt_name
|
||||
|
||||
|
||||
def get_queue_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
queue_name = f"{fn_name}_{group_key}"
|
||||
return queue_name
|
||||
|
||||
|
||||
def get_log_name(var_name, fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
log_name = f"{var_name}_{fn_name}_{group_key}"
|
||||
return log_name
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
|
||||
def find_available_ports():
|
||||
ports = []
|
||||
for path in Path("/dev").glob("tty*"):
|
||||
ports.append(str(path))
|
||||
return ports
|
||||
|
||||
|
||||
def find_port():
|
||||
print("Finding all available ports for the DynamixelMotorsBus.")
|
||||
ports_before = find_available_ports()
|
||||
print(ports_before)
|
||||
|
||||
print("Remove the usb cable from your DynamixelMotorsBus and press Enter when done.")
|
||||
input()
|
||||
|
||||
time.sleep(0.5)
|
||||
ports_after = find_available_ports()
|
||||
ports_diff = list(set(ports_before) - set(ports_after))
|
||||
|
||||
if len(ports_diff) == 1:
|
||||
port = ports_diff[0]
|
||||
print(f"The port of this DynamixelMotorsBus is '{port}'")
|
||||
print("Reconnect the usb cable.")
|
||||
elif len(ports_diff) == 0:
|
||||
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
|
||||
else:
|
||||
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
@@ -265,52 +158,9 @@ class DriveMode(enum.Enum):
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
# TODO(rcadene): Add a script to find the motor indices without DynamixelWizzard2
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
|
||||
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/motors/dynamixel.py
|
||||
>>> Finding all available ports for the DynamixelMotorsBus.
|
||||
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example of usage for 1 motor connected to the bus:
|
||||
```python
|
||||
motor_name = "gripper"
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# when done, consider disconnecting
|
||||
motors_bus.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
):
|
||||
def __init__(self, port: str, motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
|
||||
@@ -318,330 +168,64 @@ class DynamixelMotorsBus:
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
self.is_connected = False
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
"\nTry running `python lerobot/common/robot_devices/motors/dynamixel.py` to make sure you are using the correct port.\n"
|
||||
)
|
||||
raise
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port {self.port}")
|
||||
|
||||
self.port_handler.setBaudRate(BAUD_RATE)
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Set expected baudrate for the bus
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
|
||||
if not self.are_motors_configured():
|
||||
input(
|
||||
"\n/!\\ A configuration issue has been detected with your motors: \n"
|
||||
"If it's the first time that you use these motors, press enter to configure your motors... but before "
|
||||
"verify that all the cables are connected the proper way. If you find an issue, before making a modification, "
|
||||
"kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power "
|
||||
"again and relaunch the script.\n"
|
||||
)
|
||||
print()
|
||||
self.configure_motors()
|
||||
|
||||
def reconnect(self):
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
||||
# a ConnectionError will be raised anyway.
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def configure_motors(self):
|
||||
# TODO(rcadene): This script assumes motors follow the X_SERIES baudrates
|
||||
# TODO(rcadene): Refactor this function with intermediate high-level functions
|
||||
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(X_SERIES_BAUDRATE_TABLE.values())
|
||||
ids_per_baudrate = {}
|
||||
for baudrate in all_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices()
|
||||
if len(present_ids) > 0:
|
||||
ids_per_baudrate[baudrate] = present_ids
|
||||
print(f"Motor indices detected: {ids_per_baudrate}")
|
||||
print()
|
||||
|
||||
possible_baudrates = list(ids_per_baudrate.keys())
|
||||
possible_ids = list({idx for sublist in ids_per_baudrate.values() for idx in sublist})
|
||||
untaken_ids = list(set(range(MAX_ID_RANGE)) - set(possible_ids) - set(self.motor_indices))
|
||||
|
||||
# Connect successively one motor to the chain and write a unique random index for each
|
||||
for i in range(len(self.motors)):
|
||||
self.disconnect()
|
||||
input(
|
||||
"1. Unplug the power cord\n"
|
||||
"2. Plug/unplug minimal number of cables to only have the first "
|
||||
f"{i+1} motor(s) ({self.motor_names[:i+1]}) connected.\n"
|
||||
"3. Re-plug the power cord\n"
|
||||
"Press Enter to continue..."
|
||||
)
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
if i > 0:
|
||||
try:
|
||||
self._read_with_motor_ids(self.motor_models, untaken_ids[:i], "ID")
|
||||
except ConnectionError:
|
||||
print(f"Failed to read from {untaken_ids[:i+1]}. Make sure the power cord is plugged in.")
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
print("Scanning possible baudrates and motor indices")
|
||||
motor_found = False
|
||||
for baudrate in possible_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices(possible_ids)
|
||||
if len(present_ids) == 1:
|
||||
present_idx = present_ids[0]
|
||||
print(f"Detected motor with index {present_idx}")
|
||||
|
||||
if baudrate != BAUDRATE:
|
||||
print(f"Setting its baudrate to {BAUDRATE}")
|
||||
baudrate_idx = list(X_SERIES_BAUDRATE_TABLE.values()).index(BAUDRATE)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
for _ in range(NUM_WRITE_RETRY):
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate", baudrate_idx
|
||||
)
|
||||
time.sleep(0.5)
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
try:
|
||||
present_baudrate_idx = self._read_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate"
|
||||
)
|
||||
except ConnectionError:
|
||||
print("Failed to write baudrate. Retrying.")
|
||||
self.set_bus_baudrate(baudrate)
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
if present_baudrate_idx != baudrate_idx:
|
||||
raise OSError("Failed to write baudrate.")
|
||||
|
||||
print(f"Setting its index to a temporary untaken index ({untaken_ids[i]})")
|
||||
self._write_with_motor_ids(self.motor_models, present_idx, "ID", untaken_ids[i])
|
||||
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, untaken_ids[i], "ID")
|
||||
if present_idx != untaken_ids[i]:
|
||||
raise OSError("Failed to write index.")
|
||||
|
||||
motor_found = True
|
||||
break
|
||||
elif len(present_ids) > 1:
|
||||
raise OSError(f"More than one motor detected ({present_ids}), but only one was expected.")
|
||||
|
||||
if not motor_found:
|
||||
raise OSError(
|
||||
"No motor found, but one new motor expected. Verify power cord is plugged in and retry."
|
||||
)
|
||||
print()
|
||||
|
||||
print(f"Setting expected motor indices: {self.motor_indices}")
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, untaken_ids[: len(self.motors)], "ID", self.motor_indices
|
||||
)
|
||||
print()
|
||||
|
||||
if (self.read("ID") != self.motor_indices).any():
|
||||
raise OSError("Failed to write motors indices.")
|
||||
|
||||
print("Configuration is done!")
|
||||
|
||||
def find_motor_indices(self, possible_ids=None):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, [idx], "ID")[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
self.calibration = None
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[str]:
|
||||
def motor_names(self) -> list[int]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
|
||||
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
|
||||
values = values.astype(np.int32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
|
||||
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range ]-resolution, resolution[ to the universal float32 centered degree range ]-180, 180[
|
||||
values = values.astype(np.float32)
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / (resolution // 2) * 180
|
||||
|
||||
return values
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from the universal float32 centered degree range ]-180, 180[ to resolution range ]-resolution, resolution[
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / 180 * (resolution // 2)
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
|
||||
# Convert from nominal range ]-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
values[i] -= homing_offset
|
||||
|
||||
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
|
||||
# than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
return values
|
||||
|
||||
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
|
||||
return_list = True
|
||||
if not isinstance(motor_ids, list):
|
||||
return_list = False
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
comm = group.txRxPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
if not self.calibration:
|
||||
return values
|
||||
else:
|
||||
return values[0]
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
start_time = time.perf_counter()
|
||||
if values[i] is not None:
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
values[i] += homing_offset
|
||||
|
||||
return values
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if not self.calibration:
|
||||
return values
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
if values[i] is not None:
|
||||
values[i] -= homing_offset
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
return values
|
||||
|
||||
def read(self, data_name, motor_names: list[str] | None = None):
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
@@ -650,7 +234,7 @@ class DynamixelMotorsBus:
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
# TODO(rcadene): assert all motors follow same address
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
@@ -660,11 +244,7 @@ class DynamixelMotorsBus:
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
|
||||
for _ in range(NUM_READ_RETRY):
|
||||
comm = self.group_readers[group_key].txRxPacket()
|
||||
if comm == COMM_SUCCESS:
|
||||
break
|
||||
|
||||
comm = self.group_readers[group_key].txRxPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
@@ -678,74 +258,25 @@ class DynamixelMotorsBus:
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
# Convert to signed int to use range [-2048, 2048] for our motor positions.
|
||||
# TODO(rcadene): explain why
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
values = uint32_to_int32(values)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
|
||||
# We expect our motors to stay in a nominal range of [-180, 180] degrees
|
||||
# which corresponds to a half turn rotation.
|
||||
# However, some motors can turn a bit more, hence we extend the nominal range to [-270, 270]
|
||||
# which is less than a full 360 degree rotation.
|
||||
if not np.all((values > -270) & (values < 270)):
|
||||
raise ValueError(
|
||||
f"Wrong motor position range detected. "
|
||||
f"Expected to be in [-270, +270] but in [{values.min()}, {values.max()}]. "
|
||||
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
if data_name in CONVERT_POSITION_TO_ANGLE_REQUIRED:
|
||||
values = motor_position_to_angle(values)
|
||||
|
||||
return values
|
||||
|
||||
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
|
||||
if not isinstance(motor_ids, list):
|
||||
motor_ids = [motor_ids]
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes)
|
||||
group.addParam(idx, data)
|
||||
|
||||
comm = group.txPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
|
||||
if isinstance(values, (int, float, np.integer)):
|
||||
values = [int(values)] * len(motor_names)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
for name in motor_names:
|
||||
@@ -753,23 +284,53 @@ class DynamixelMotorsBus:
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
if isinstance(values, (int, float, np.integer)):
|
||||
values = [int(values)] * len(motor_ids)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
if data_name in CONVERT_POSITION_TO_ANGLE_REQUIRED:
|
||||
values = motor_angle_to_position(values)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
values = self.revert_calibration(values, motor_names)
|
||||
|
||||
# TODO(rcadene): why dont we do it?
|
||||
# if data_name in CONVERT_INT32_TO_UINT32_REQUIRED:
|
||||
# values = int32_to_uint32(values)
|
||||
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
# TODO(rcadene): assert all motors follow same address
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
|
||||
for idx, value in zip(motor_ids, values):
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead.")
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
@@ -782,35 +343,63 @@ class DynamixelMotorsBus:
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
# def read(self, data_name, motor_name: str):
|
||||
# motor_idx, model = self.motors[motor_name]
|
||||
# addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
# log the utc time when the write has been completed
|
||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
# args = (self.port_handler, motor_idx, addr)
|
||||
# if bytes == 1:
|
||||
# value, comm, err = self.packet_handler.read1ByteTxRx(*args)
|
||||
# elif bytes == 2:
|
||||
# value, comm, err = self.packet_handler.read2ByteTxRx(*args)
|
||||
# elif bytes == 4:
|
||||
# value, comm, err = self.packet_handler.read4ByteTxRx(*args)
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
# f"{bytes} is provided instead.")
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
|
||||
)
|
||||
# if comm != COMM_SUCCESS:
|
||||
# raise ConnectionError(
|
||||
# f"Read failed due to communication error on port {self.port} for motor {motor_idx}: "
|
||||
# f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
# )
|
||||
# elif err != 0:
|
||||
# raise ConnectionError(
|
||||
# f"Read failed due to error {err} on port {self.port} for motor {motor_idx}: "
|
||||
# f"{self.packet_handler.getTxRxResult(err)}"
|
||||
# )
|
||||
|
||||
if self.port_handler is not None:
|
||||
self.port_handler.closePort()
|
||||
self.port_handler = None
|
||||
# if data_name in CALIBRATION_REQUIRED:
|
||||
# value = self.apply_calibration([value], [motor_name])[0]
|
||||
|
||||
self.packet_handler = None
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.is_connected = False
|
||||
# return value
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
# def write(self, data_name, value, motor_name: str):
|
||||
# if data_name in CALIBRATION_REQUIRED:
|
||||
# value = self.revert_calibration([value], [motor_name])[0]
|
||||
|
||||
# motor_idx, model = self.motors[motor_name]
|
||||
# addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
# args = (self.port_handler, motor_idx, addr, value)
|
||||
# if bytes == 1:
|
||||
# comm, err = self.packet_handler.write1ByteTxRx(*args)
|
||||
# elif bytes == 2:
|
||||
# comm, err = self.packet_handler.write2ByteTxRx(*args)
|
||||
# elif bytes == 4:
|
||||
# comm, err = self.packet_handler.write4ByteTxRx(*args)
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
|
||||
# f"is provided instead.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Helper to find the usb port associated to all your DynamixelMotorsBus.
|
||||
find_port()
|
||||
# if comm != COMM_SUCCESS:
|
||||
# raise ConnectionError(
|
||||
# f"Write failed due to communication error on port {self.port} for motor {motor_idx}: "
|
||||
# f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
# )
|
||||
# elif err != 0:
|
||||
# raise ConnectionError(
|
||||
# f"Write failed due to error {err} on port {self.port} for motor {motor_idx}: "
|
||||
# f"{self.packet_handler.getTxRxResult(err)}"
|
||||
# )
|
||||
713
lerobot/common/robot_devices/motors/feetech.py
Normal file
713
lerobot/common/robot_devices/motors/feetech.py
Normal file
@@ -0,0 +1,713 @@
|
||||
from copy import deepcopy
|
||||
import enum
|
||||
import numpy as np
|
||||
|
||||
from scservo_sdk import PacketHandler, PortHandler, COMM_SUCCESS, GroupSyncRead, GroupSyncWrite
|
||||
from scservo_sdk import SCS_HIBYTE, SCS_HIBYTE, SCS_LOBYTE, SCS_LOWORD
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
BAUD_RATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
|
||||
def u32_to_i32(value: int | np.array) -> int | np.array:
|
||||
"""
|
||||
Convert an unsigned 32-bit integer array to a signed 32-bit integer array.
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
if value > 2147483647:
|
||||
value = value - 4294967296
|
||||
else:
|
||||
for i in range(len(value)):
|
||||
if value[i] is not None and value[i] > 2147483647:
|
||||
value[i] = value[i] - 4294967296
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def i32_to_u32(value: int | np.array) -> int | np.array:
|
||||
"""
|
||||
Convert a signed 32-bit integer array to an unsigned 32-bit integer array.
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
if value < 0:
|
||||
value = value + 4294967296
|
||||
else:
|
||||
for i in range(len(value)):
|
||||
if value[i] is not None and value[i] < 0:
|
||||
value[i] = value[i] + 4294967296
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def retrieve_ids_and_command(values: np.array, ids: np.array) -> (list[int], np.array):
|
||||
"""
|
||||
Convert the values to a chain command. Skip the None values and return the ids and values.
|
||||
"""
|
||||
|
||||
non_none_values = np.array([value for value in values if value is not None])
|
||||
non_none_values_ids = [ids[i] for i, value in enumerate(values) if value is not None]
|
||||
|
||||
return non_none_values_ids, non_none_values
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
class OperatingMode(enum.Enum):
|
||||
pass
|
||||
|
||||
|
||||
class DriveMode(enum.Enum):
|
||||
pass
|
||||
|
||||
|
||||
SCS_SERIES_CONTROL_TABLE = [
|
||||
("Model", 3, 2),
|
||||
("ID", 5, 1),
|
||||
("Baud_Rate", 6, 1),
|
||||
("Return_Delay", 7, 1),
|
||||
("Response_Status_Level", 8, 1),
|
||||
("Min_Angle_Limit", 9, 2),
|
||||
("Max_Angle_Limit", 11, 2),
|
||||
("Max_Temperature_Limit", 13, 1),
|
||||
("Max_Voltage_Limit", 14, 1),
|
||||
("Min_Voltage_Limit", 15, 1),
|
||||
("Max_Torque_Limit", 16, 2),
|
||||
("Phase", 18, 1),
|
||||
("Unloading_Condition", 19, 1),
|
||||
("LED_Alarm_Condition", 20, 1),
|
||||
("P_Coefficient", 21, 1),
|
||||
("D_Coefficient", 22, 1),
|
||||
("I_Coefficient", 23, 1),
|
||||
("Minimum_Startup_Force", 24, 2),
|
||||
("CW_Dead_Zone", 26, 1),
|
||||
("CCW_Dead_Zone", 27, 1),
|
||||
("Protection_Current", 28, 2),
|
||||
("Angular_Resolution", 30, 1),
|
||||
("Offset", 31, 2),
|
||||
("Mode", 33, 1),
|
||||
("Protective_Torque", 34, 1),
|
||||
("Protection_Time", 35, 1),
|
||||
("Overload_Torque", 36, 1),
|
||||
("Speed_closed_loop_P_proportional_coefficient", 37, 1),
|
||||
("Over_Current_Protection_Time", 38, 1),
|
||||
("Velocity_closed_loop_I_integral_coefficient", 39, 1),
|
||||
("Torque_Enable", 40, 1),
|
||||
("Acceleration", 41, 1),
|
||||
("Goal_Position", 42, 2),
|
||||
("Goal_Time", 44, 2),
|
||||
("Goal_Speed", 46, 2),
|
||||
("Lock", 55, 1),
|
||||
("Present_Position", 56, 2),
|
||||
("Present_Speed", 58, 2),
|
||||
("Present_Load", 60, 2),
|
||||
("Present_Voltage", 62, 1),
|
||||
("Present_Temperature", 63, 1),
|
||||
("Status", 65, 1),
|
||||
("Moving", 66, 1),
|
||||
("Present_Current", 69, 2)
|
||||
]
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
||||
"sts3215": SCS_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
|
||||
class FeetechBus:
|
||||
|
||||
def __init__(self, port: str, motor_models: dict[int, str],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None):
|
||||
self.port = port
|
||||
self.motor_models = motor_models
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
# Find read/write addresses and number of bytes for each motor
|
||||
self.motor_ctrl = {}
|
||||
for idx, model in self.motor_models.items():
|
||||
for data_name, addr, bytes in self.model_ctrl_table[model]:
|
||||
if idx not in self.motor_ctrl:
|
||||
self.motor_ctrl[idx] = {}
|
||||
self.motor_ctrl[idx][data_name] = {
|
||||
"addr": addr,
|
||||
"bytes": bytes,
|
||||
}
|
||||
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port {self.port}")
|
||||
|
||||
self.port_handler.setBaudRate(BAUD_RATE)
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
|
||||
@property
|
||||
def motor_ids(self) -> list[int]:
|
||||
return list(self.motor_models.keys())
|
||||
|
||||
def close(self):
|
||||
self.port_handler.closePort()
|
||||
|
||||
def write(self, data_name, value, motor_idx: int):
|
||||
|
||||
addr = self.motor_ctrl[motor_idx][data_name]["addr"]
|
||||
bytes = self.motor_ctrl[motor_idx][data_name]["bytes"]
|
||||
args = (self.port_handler, motor_idx, addr, value)
|
||||
if bytes == 1:
|
||||
comm, err = self.packet_handler.write1ByteTxRx(*args)
|
||||
elif bytes == 2:
|
||||
comm, err = self.packet_handler.write2ByteTxRx(*args)
|
||||
elif bytes == 4:
|
||||
comm, err = self.packet_handler.write4ByteTxRx(*args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
|
||||
f"is provided instead.")
|
||||
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port} for motor {motor_idx}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
elif err != 0:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to error {err} on port {self.port} for motor {motor_idx}: "
|
||||
f"{self.packet_handler.getTxRxResult(err)}"
|
||||
)
|
||||
|
||||
def read(self, data_name, motor_idx: int):
|
||||
addr = self.motor_ctrl[motor_idx][data_name]["addr"]
|
||||
bytes = self.motor_ctrl[motor_idx][data_name]["bytes"]
|
||||
args = (self.port_handler, motor_idx, addr)
|
||||
if bytes == 1:
|
||||
value, comm, err = self.packet_handler.read1ByteTxRx(*args)
|
||||
elif bytes == 2:
|
||||
value, comm, err = self.packet_handler.read2ByteTxRx(*args)
|
||||
elif bytes == 4:
|
||||
value, comm, err = self.packet_handler.read4ByteTxRx(*args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead.")
|
||||
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port} for motor {motor_idx}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
elif err != 0:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to error {err} on port {self.port} for motor {motor_idx}: "
|
||||
f"{self.packet_handler.getTxRxResult(err)}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def sync_read(self, data_name, motor_ids: list[int] | None = None):
|
||||
if motor_ids is None:
|
||||
motor_ids = self.motor_ids
|
||||
|
||||
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
|
||||
first_motor_idx = list(self.motor_ctrl.keys())[0]
|
||||
addr = self.motor_ctrl[first_motor_idx][data_name]["addr"]
|
||||
bytes = self.motor_ctrl[first_motor_idx][data_name]["bytes"]
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
self.group_readers[group_key] = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
|
||||
comm = self.group_readers[group_key].txRxPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
return np.array(values)
|
||||
|
||||
def sync_write(self, data_name, values: int | list[int], motor_ids: int | list[int] | None = None):
|
||||
if motor_ids is None:
|
||||
motor_ids = self.motor_ids
|
||||
|
||||
if isinstance(motor_ids, int):
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
if isinstance(values, (int, np.integer)):
|
||||
values = [int(values)] * len(motor_ids)
|
||||
|
||||
if isinstance(values, np.ndarray):
|
||||
values = values.tolist()
|
||||
|
||||
group_key = f"{data_name}_" + "_".join([str(idx) for idx in motor_ids])
|
||||
|
||||
first_motor_idx = list(self.motor_ctrl.keys())[0]
|
||||
addr = self.motor_ctrl[first_motor_idx][data_name]["addr"]
|
||||
bytes = self.motor_ctrl[first_motor_idx][data_name]["bytes"]
|
||||
init_group = data_name not in self.group_readers
|
||||
|
||||
if init_group:
|
||||
self.group_writers[group_key] = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
|
||||
for idx, value in zip(motor_ids, values):
|
||||
if bytes == 1:
|
||||
data = [
|
||||
SCS_LOBYTE(SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
SCS_LOBYTE(SCS_LOWORD(value)),
|
||||
SCS_HIBYTE(SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
SCS_LOBYTE(SCS_LOWORD(value)),
|
||||
SCS_HIBYTE(SCS_LOWORD(value)),
|
||||
SCS_LOBYTE(SCS_HIBYTE(value)),
|
||||
SCS_HIBYTE(SCS_HIBYTE(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but {bytes} "
|
||||
f"is provided instead.")
|
||||
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
self.group_writers[group_key].changeParam(idx, data)
|
||||
|
||||
comm = self.group_writers[group_key].txPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def read_model(self, motor_idx: int):
|
||||
return self.read("Model", motor_idx)
|
||||
|
||||
def sync_read_model(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Model", motor_ids)
|
||||
|
||||
def write_id(self, value, motor_idx: int):
|
||||
self.write("ID", value, motor_idx)
|
||||
|
||||
def read_id(self, motor_idx: int):
|
||||
return self.read("ID", motor_idx)
|
||||
|
||||
def sync_read_id(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("ID", motor_ids)
|
||||
|
||||
def sync_write_id(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("ID", values, motor_ids)
|
||||
|
||||
def write_baud_rate(self, value, motor_idx: int):
|
||||
self.write("Baud_Rate", value, motor_idx)
|
||||
|
||||
def read_baud_rate(self, motor_idx: int):
|
||||
return self.read("Baud_Rate", motor_idx)
|
||||
|
||||
def sync_read_baud_rate(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Baud_Rate", motor_ids)
|
||||
|
||||
def sync_write_baud_rate(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Baud_Rate", values, motor_ids)
|
||||
|
||||
def read_return_delay(self, motor_idx: int):
|
||||
return self.read("Return_Delay", motor_idx)
|
||||
|
||||
def sync_read_return_delay(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Return_Delay", motor_ids)
|
||||
|
||||
def read_response_status_level(self, motor_idx: int):
|
||||
return self.read("Response_Status_Level", motor_idx)
|
||||
|
||||
def sync_read_response_status_level(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Response_Status_Level", motor_ids)
|
||||
|
||||
def write_min_angle_limit(self, value, motor_idx: int):
|
||||
self.write("Min_Angle_Limit", value, motor_idx)
|
||||
|
||||
def read_min_angle_limit(self, motor_idx: int):
|
||||
return self.read("Min_Angle_Limit", motor_idx)
|
||||
|
||||
def sync_read_min_angle_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Min_Angle_Limit", motor_ids)
|
||||
|
||||
def sync_write_min_angle_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Min_Angle_Limit", values, motor_ids)
|
||||
|
||||
def write_max_angle_limit(self, value, motor_idx: int):
|
||||
self.write("Max_Angle_Limit", value, motor_idx)
|
||||
|
||||
def read_max_angle_limit(self, motor_idx: int):
|
||||
return self.read("Max_Angle_Limit", motor_idx)
|
||||
|
||||
def sync_read_max_angle_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Max_Angle_Limit", motor_ids)
|
||||
|
||||
def sync_write_max_angle_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Max_Angle_Limit", values, motor_ids)
|
||||
|
||||
def write_max_temperature_limit(self, value, motor_idx: int):
|
||||
self.write("Max_Temperature_Limit", value, motor_idx)
|
||||
|
||||
def read_max_temperature_limit(self, motor_idx: int):
|
||||
return self.read("Max_Temperature_Limit", motor_idx)
|
||||
|
||||
def sync_read_max_temperature_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Max_Temperature_Limit", motor_ids)
|
||||
|
||||
def sync_write_max_temperature_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Max_Temperature_Limit", values, motor_ids)
|
||||
|
||||
def write_max_voltage_limit(self, value, motor_idx: int):
|
||||
self.write("Max_Voltage_Limit", value, motor_idx)
|
||||
|
||||
def read_max_voltage_limit(self, motor_idx: int):
|
||||
return self.read("Max_Voltage_Limit", motor_idx)
|
||||
|
||||
def sync_read_max_voltage_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Max_Voltage_Limit", motor_ids)
|
||||
|
||||
def sync_write_max_voltage_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Max_Voltage_Limit", values, motor_ids)
|
||||
|
||||
def write_min_voltage_limit(self, value, motor_idx: int):
|
||||
self.write("Min_Voltage_Limit", value, motor_idx)
|
||||
|
||||
def read_min_voltage_limit(self, motor_idx: int):
|
||||
return self.read("Min_Voltage_Limit", motor_idx)
|
||||
|
||||
def sync_read_min_voltage_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Min_Voltage_Limit", motor_ids)
|
||||
|
||||
def sync_write_min_voltage_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Min_Voltage_Limit", values, motor_ids)
|
||||
|
||||
def write_max_torque_limit(self, value, motor_idx: int):
|
||||
self.write("Max_Torque_Limit", value, motor_idx)
|
||||
|
||||
def read_max_torque_limit(self, motor_idx: int):
|
||||
return self.read("Max_Torque_Limit", motor_idx)
|
||||
|
||||
def sync_read_max_torque_limit(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Max_Torque_Limit", motor_ids)
|
||||
|
||||
def sync_write_max_torque_limit(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Max_Torque_Limit", values, motor_ids)
|
||||
|
||||
def write_p_coefficient(self, value, motor_idx: int):
|
||||
self.write("P_Coefficient", value, motor_idx)
|
||||
|
||||
def read_p_coefficient(self, motor_idx: int):
|
||||
return self.read("P_Coefficient", motor_idx)
|
||||
|
||||
def sync_read_p_coefficient(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("P_Coefficient", motor_ids)
|
||||
|
||||
def sync_write_p_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("P_Coefficient", values, motor_ids)
|
||||
|
||||
def write_d_coefficient(self, value, motor_idx: int):
|
||||
self.write("D_Coefficient", value, motor_idx)
|
||||
|
||||
def read_d_coefficient(self, motor_idx: int):
|
||||
return self.read("D_Coefficient", motor_idx)
|
||||
|
||||
def sync_read_d_coefficient(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("D_Coefficient", motor_ids)
|
||||
|
||||
def sync_write_d_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("D_Coefficient", values, motor_ids)
|
||||
|
||||
def write_i_coefficient(self, value, motor_idx: int):
|
||||
self.write("I_Coefficient", value, motor_idx)
|
||||
|
||||
def read_i_coefficient(self, motor_idx: int):
|
||||
return self.read("I_Coefficient", motor_idx)
|
||||
|
||||
def sync_read_i_coefficient(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("I_Coefficient", motor_ids)
|
||||
|
||||
def sync_write_i_coefficient(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("I_Coefficient", values, motor_ids)
|
||||
|
||||
def write_minimum_startup_force(self, value, motor_idx: int):
|
||||
self.write("Minimum_Startup_Force", value, motor_idx)
|
||||
|
||||
def read_minimum_startup_force(self, motor_idx: int):
|
||||
return self.read("Minimum_Startup_Force", motor_idx)
|
||||
|
||||
def sync_read_minimum_startup_force(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Minimum_Startup_Force", motor_ids)
|
||||
|
||||
def sync_write_minimum_startup_force(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Minimum_Startup_Force", values, motor_ids)
|
||||
|
||||
def write_cw_dead_zone(self, value, motor_idx: int):
|
||||
self.write("CW_Dead_Zone", value, motor_idx)
|
||||
|
||||
def read_cw_dead_zone(self, motor_idx: int):
|
||||
return self.read("CW_Dead_Zone", motor_idx)
|
||||
|
||||
def sync_read_cw_dead_zone(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("CW_Dead_Zone", motor_ids)
|
||||
|
||||
def sync_write_cw_dead_zone(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("CW_Dead_Zone", values, motor_ids)
|
||||
|
||||
def write_ccw_dead_zone(self, value, motor_idx: int):
|
||||
self.write("CCW_Dead_Zone", value, motor_idx)
|
||||
|
||||
def read_ccw_dead_zone(self, motor_idx: int):
|
||||
return self.read("CCW_Dead_Zone", motor_idx)
|
||||
|
||||
def sync_read_ccw_dead_zone(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("CCW_Dead_Zone", motor_ids)
|
||||
|
||||
def sync_write_ccw_dead_zone(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("CCW_Dead_Zone", values, motor_ids)
|
||||
|
||||
def write_protection_current(self, value, motor_idx: int):
|
||||
self.write("Protection_Current", value, motor_idx)
|
||||
|
||||
def read_protection_current(self, motor_idx: int):
|
||||
return self.read("Protection_Current", motor_idx)
|
||||
|
||||
def sync_read_protection_current(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Protection_Current", motor_ids)
|
||||
|
||||
def sync_write_protection_current(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Protection_Current", values, motor_ids)
|
||||
|
||||
def read_angular_resolution(self, motor_idx: int):
|
||||
return self.read("Angular_Resolution", motor_idx)
|
||||
|
||||
def sync_read_angular_resolution(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Angular_Resolution", motor_ids)
|
||||
|
||||
def write_offset(self, value, motor_idx: int):
|
||||
self.write("Offset", value, motor_idx)
|
||||
|
||||
def read_offset(self, motor_idx: int):
|
||||
return self.read("Offset", motor_idx)
|
||||
|
||||
def sync_read_offset(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Offset", motor_ids)
|
||||
|
||||
def sync_write_offset(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Offset", values, motor_ids)
|
||||
|
||||
def write_mode(self, value, motor_idx: int):
|
||||
self.write("Mode", value, motor_idx)
|
||||
|
||||
def read_mode(self, motor_idx: int):
|
||||
return self.read("Mode", motor_idx)
|
||||
|
||||
def sync_read_mode(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Mode", motor_ids)
|
||||
|
||||
def sync_write_mode(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Mode", values, motor_ids)
|
||||
|
||||
def write_protective_torque(self, value, motor_idx: int):
|
||||
self.write("Protective_Torque", value, motor_idx)
|
||||
|
||||
def read_protective_torque(self, motor_idx: int):
|
||||
return self.read("Protective_Torque", motor_idx)
|
||||
|
||||
def sync_read_protective_torque(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Protective_Torque", motor_ids)
|
||||
|
||||
def sync_write_protective_torque(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Protective_Torque", values, motor_ids)
|
||||
|
||||
def read_protection_time(self, motor_idx: int):
|
||||
return self.read("Protection_Time", motor_idx)
|
||||
|
||||
def sync_read_protection_time(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Protection_Time", motor_ids)
|
||||
|
||||
def write_speed_closed_loop_p_proportional_coefficient(self, value, motor_idx: int):
|
||||
self.write("Speed_closed_loop_P_proportional_coefficient", value, motor_idx)
|
||||
|
||||
def read_speed_closed_loop_p_proportional_coefficient(self, motor_idx: int):
|
||||
return self.read("Speed_closed_loop_P_proportional_coefficient", motor_idx)
|
||||
|
||||
def sync_read_speed_closed_loop_p_proportional_coefficient(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Speed_closed_loop_P_proportional_coefficient", motor_ids)
|
||||
|
||||
def sync_write_speed_closed_loop_p_proportional_coefficient(self, values: int | list[int],
|
||||
motor_ids: list[int] | None = None):
|
||||
self.sync_write("Speed_closed_loop_P_proportional_coefficient", values, motor_ids)
|
||||
|
||||
def write_over_current_protection_time(self, value, motor_idx: int):
|
||||
self.write("Over_Current_Protection_Time", value, motor_idx)
|
||||
|
||||
def read_over_current_protection_time(self, motor_idx: int):
|
||||
return self.read("Over_Current_Protection_Time", motor_idx)
|
||||
|
||||
def sync_read_over_current_protection_time(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Over_Current_Protection_Time", motor_ids)
|
||||
|
||||
def sync_write_over_current_protection_time(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Over_Current_Protection_Time", values, motor_ids)
|
||||
|
||||
def write_velocity_closed_loop_i_integral_coefficient(self, value, motor_idx: int):
|
||||
self.write("Velocity_closed_loop_I_integral_coefficient", value, motor_idx)
|
||||
|
||||
def read_velocity_closed_loop_i_integral_coefficient(self, motor_idx: int):
|
||||
return self.read("Velocity_closed_loop_I_integral_coefficient", motor_idx)
|
||||
|
||||
def sync_read_velocity_closed_loop_i_integral_coefficient(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Velocity_closed_loop_I_integral_coefficient", motor_ids)
|
||||
|
||||
def sync_write_velocity_closed_loop_i_integral_coefficient(self, values: int | list[int],
|
||||
motor_ids: list[int] | None = None):
|
||||
self.sync_write("Velocity_closed_loop_I_integral_coefficient", values, motor_ids)
|
||||
|
||||
def write_torque_enable(self, value, motor_idx: int):
|
||||
self.write("Torque_Enable", value, motor_idx)
|
||||
|
||||
def read_torque_enable(self, motor_idx: int):
|
||||
return self.read("Torque_Enable", motor_idx)
|
||||
|
||||
def sync_read_torque_enable(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Torque_Enable", motor_ids)
|
||||
|
||||
def sync_write_torque_enable(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Torque_Enable", values, motor_ids)
|
||||
|
||||
def write_goal_position_u32(self, value, motor_idx: int):
|
||||
self.write("Goal_Position", value, motor_idx)
|
||||
|
||||
def write_goal_position_i32(self, value, motor_idx: int):
|
||||
self.write("Goal_Position", i32_to_u32(value), motor_idx)
|
||||
|
||||
def read_goal_position_u32(self, motor_idx: int):
|
||||
return self.read("Goal_Position", motor_idx)
|
||||
|
||||
def read_goal_position_i32(self, motor_idx: int):
|
||||
goal_position_u32 = self.read_goal_position_u32(motor_idx)
|
||||
|
||||
return u32_to_i32(goal_position_u32)
|
||||
|
||||
def sync_read_goal_position_u32(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Goal_Position", motor_ids)
|
||||
|
||||
def sync_read_goal_position_i32(self, motor_ids: list[int] | None = None):
|
||||
goal_position_u32 = self.sync_read_goal_position_u32(motor_ids)
|
||||
|
||||
return u32_to_i32(goal_position_u32)
|
||||
|
||||
def sync_write_goal_position_u32(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Goal_Position", values, motor_ids)
|
||||
|
||||
def sync_write_goal_position_i32(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Goal_Position", i32_to_u32(values), motor_ids)
|
||||
|
||||
def write_goal_time(self, value, motor_idx: int):
|
||||
self.write("Goal_Time", value, motor_idx)
|
||||
|
||||
def read_goal_time(self, motor_idx: int):
|
||||
return self.read("Goal_Time", motor_idx)
|
||||
|
||||
def sync_read_goal_time(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Goal_Time", motor_ids)
|
||||
|
||||
def sync_write_goal_time(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Goal_Time", values, motor_ids)
|
||||
|
||||
def write_goal_speed(self, value, motor_idx: int):
|
||||
self.write("Goal_Speed", value, motor_idx)
|
||||
|
||||
def read_goal_speed(self, motor_idx: int):
|
||||
return self.read("Goal_Speed", motor_idx)
|
||||
|
||||
def sync_read_goal_speed(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Goal_Speed", motor_ids)
|
||||
|
||||
def sync_write_goal_speed(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Goal_Speed", values, motor_ids)
|
||||
|
||||
def write_lock(self, value, motor_idx: int):
|
||||
self.write("Lock", value, motor_idx)
|
||||
|
||||
def read_lock(self, motor_idx: int):
|
||||
return self.read("Lock", motor_idx)
|
||||
|
||||
def sync_read_lock(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Lock", motor_ids)
|
||||
|
||||
def sync_write_lock(self, values: int | list[int], motor_ids: list[int] | None = None):
|
||||
self.sync_write("Lock", values, motor_ids)
|
||||
|
||||
def read_present_position_u32(self, motor_idx: int):
|
||||
return self.read("Present_Position", motor_idx)
|
||||
|
||||
def read_present_position_i32(self, motor_idx: int):
|
||||
present_position_u32 = self.read_present_position_u32(motor_idx)
|
||||
|
||||
return u32_to_i32(present_position_u32)
|
||||
|
||||
def sync_read_present_position_u32(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Position", motor_ids)
|
||||
|
||||
def sync_read_present_position_i32(self, motor_ids: list[int] | None = None):
|
||||
present_position_u32 = self.sync_read_present_position_u32(motor_ids)
|
||||
|
||||
return u32_to_i32(present_position_u32)
|
||||
|
||||
def read_present_speed(self, motor_idx: int):
|
||||
return self.read("Present_Speed", motor_idx)
|
||||
|
||||
def sync_read_present_speed(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Speed", motor_ids)
|
||||
|
||||
def read_present_load(self, motor_idx: int):
|
||||
return self.read("Present_Load", motor_idx)
|
||||
|
||||
def sync_read_present_load(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Load", motor_ids)
|
||||
|
||||
def read_present_voltage(self, motor_idx: int):
|
||||
return self.read("Present_Voltage", motor_idx)
|
||||
|
||||
def sync_read_present_voltage(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Voltage", motor_ids)
|
||||
|
||||
def read_present_temperature(self, motor_idx: int):
|
||||
return self.read("Present_Temperature", motor_idx)
|
||||
|
||||
def sync_read_present_temperature(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Temperature", motor_ids)
|
||||
|
||||
def read_moving(self, motor_idx: int):
|
||||
return self.read("Moving", motor_idx)
|
||||
|
||||
def sync_read_moving(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Moving", motor_ids)
|
||||
|
||||
def read_present_current(self, motor_idx: int):
|
||||
return self.read("Present_Current", motor_idx)
|
||||
|
||||
def sync_read_present_current(self, motor_ids: list[int] | None = None):
|
||||
return self.sync_read("Present_Current", motor_ids)
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
def set_calibration(self): ...
|
||||
|
||||
204
lerobot/common/robot_devices/robots/aloha.py
Normal file
204
lerobot/common/robot_devices/robots/aloha.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import copy
|
||||
from dataclasses import dataclass, field, replace
|
||||
import numpy as np
|
||||
import torch
|
||||
from examples.real_robot_example.gym_real_world.robot import Robot
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
|
||||
MAX_LEADER_GRIPPER_RAD = 0.7761942786701344
|
||||
MAX_LEADER_GRIPPER_POS = 2567
|
||||
MAX_FOLLOWER_GRIPPER_RAD = 1.6827769243105486
|
||||
MAX_FOLLOWER_GRIPPER_POS = 3100
|
||||
|
||||
MIN_LEADER_GRIPPER_RAD = -0.12732040539450828
|
||||
MIN_LEADER_GRIPPER_POS = 1984
|
||||
MIN_FOLLOWER_GRIPPER_RAD = 0.6933593161243099
|
||||
MIN_FOLLOWER_GRIPPER_POS = 2512
|
||||
|
||||
GRIPPER_INDEX = -1
|
||||
|
||||
def convert_gripper_range_from_leader_to_follower(leader_pos):
|
||||
follower_goal_pos = copy.copy(leader_pos)
|
||||
follower_goal_pos[GRIPPER_INDEX] = \
|
||||
(leader_pos[GRIPPER_INDEX] - MIN_LEADER_GRIPPER_POS) \
|
||||
/ (MAX_LEADER_GRIPPER_POS - MIN_LEADER_GRIPPER_POS) \
|
||||
* (MAX_FOLLOWER_GRIPPER_POS - MIN_FOLLOWER_GRIPPER_POS) \
|
||||
+ MIN_FOLLOWER_GRIPPER_POS
|
||||
return follower_goal_pos
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlohaRobotConfig:
|
||||
"""
|
||||
Example of usage:
|
||||
```python
|
||||
AlohaRobotConfig()
|
||||
```
|
||||
|
||||
Example of only using left arm:
|
||||
```python
|
||||
AlohaRobotConfig(
|
||||
activated_leaders=["left"],
|
||||
activated_followers=["left"],
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
# Define all the components of the robot
|
||||
leader_devices: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"right": {
|
||||
#"port": "/dev/ttyDXL_master_right",
|
||||
"port": "/dev/ttyDXL_master_left",
|
||||
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
},
|
||||
"left": {
|
||||
"port": "/dev/ttyDXL_master_left",
|
||||
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
},
|
||||
}
|
||||
)
|
||||
follower_devices: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"right": {
|
||||
"port": "/dev/ttyDXL_puppet_right",
|
||||
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
},
|
||||
"left": {
|
||||
"port": "/dev/ttyDXL_puppet_left",
|
||||
"servos": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
},
|
||||
}
|
||||
)
|
||||
camera_devices: dict[str, Camera] = field(
|
||||
default_factory=lambda: {
|
||||
# "cam_high": OpenCVCamera(16),
|
||||
# "cam_low": OpenCVCamera(4),
|
||||
# "cam_left_wrist": OpenCVCamera(10),
|
||||
# "cam_right_wrist": OpenCVCamera(22),
|
||||
}
|
||||
)
|
||||
|
||||
# Allows to easily pick a subset of all devices
|
||||
activated_leaders: list[str] | None = field(
|
||||
default_factory=lambda: ["left", "right"]
|
||||
)
|
||||
activated_followers: list[str] | None = field(
|
||||
default_factory=lambda: ["left", "right"]
|
||||
)
|
||||
activated_cameras: list[str] | None = field(
|
||||
default_factory=lambda: ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
)
|
||||
|
||||
|
||||
class AlohaRobot():
|
||||
""" Trossen Robotics
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
robot = AlohaRobot()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: AlohaRobotConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = AlohaRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
self.config = config
|
||||
|
||||
self.leaders = {}
|
||||
self.followers = {}
|
||||
self.cameras = {}
|
||||
|
||||
if config.activated_leaders:
|
||||
for name in config.activated_leaders:
|
||||
info = config.leader_devices[name]
|
||||
self.leaders[name] = Robot(info["port"], servo_ids=info["servos"])
|
||||
|
||||
if config.activated_followers:
|
||||
for name in config.activated_followers:
|
||||
info = config.follower_devices[name]
|
||||
self.followers[name] = Robot(info["port"], servo_ids=info["servos"])
|
||||
|
||||
if config.activated_cameras:
|
||||
for name in config.activated_cameras:
|
||||
self.cameras[name] = config.camera_devices[name]
|
||||
|
||||
def init_teleop(self):
|
||||
for name in self.followers:
|
||||
self.followers[name]._enable_torque()
|
||||
for name in self.cameras:
|
||||
self.cameras[name].connect()
|
||||
|
||||
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
# Prepare to assign the positions of the leader to the follower
|
||||
leader_pos = {}
|
||||
for name in self.leaders:
|
||||
leader_pos[name] = self.leaders[name].read_position()
|
||||
|
||||
# Update the position of the follower gripper to account for the different minimum and maximum range
|
||||
# position in range [0, 4096[ which corresponds to 4096 bins of 360 degrees
|
||||
# for all our dynamixel servors
|
||||
# gripper id=8 has a different range from leader to follower
|
||||
follower_goal_pos = {}
|
||||
for name in self.leaders:
|
||||
follower_goal_pos[name] = convert_gripper_range_from_leader_to_follower(leader_pos[name])
|
||||
|
||||
# Send action
|
||||
for name in self.followers:
|
||||
self.followers[name].set_goal_pos(follower_goal_pos[name])
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.followers:
|
||||
follower_pos[name] = self.followers[name].read_position()
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
for name in ["left", "right"]:
|
||||
if name in follower_pos:
|
||||
state.append(follower_pos[name])
|
||||
state = np.concatenate(state)
|
||||
state = pwm2pos(state)
|
||||
|
||||
# Create action by concatenating follower goal position
|
||||
action = []
|
||||
for name in ["left", "right"]:
|
||||
if name in follower_goal_pos:
|
||||
action.append(follower_goal_pos[name])
|
||||
action = np.concatenate(action)
|
||||
action = pwm2pos(action)
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
images[name] = self.cameras[name].read()
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
action_dict["action"] = torch.from_numpy(action)
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def send_action(self, action):
|
||||
from_idx = 0
|
||||
to_idx = 0
|
||||
follower_goal_pos = {}
|
||||
for name in ["left", "right"]:
|
||||
if name in self.followers:
|
||||
to_idx += len(self.config.follower_devices[name]["servos"])
|
||||
follower_goal_pos[name] = pos2pwm(action[from_idx:to_idx].numpy())
|
||||
from_idx = to_idx
|
||||
|
||||
for name in self.followers:
|
||||
self.followers[name].set_goal_pos(follower_goal_pos[name])
|
||||
@@ -1,7 +1,46 @@
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def make_robot(cfg: DictConfig):
|
||||
robot = hydra.utils.instantiate(cfg)
|
||||
return robot
|
||||
def make_robot(name):
|
||||
|
||||
if name == "koch":
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
robot = KochRobot(
|
||||
leader_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
},
|
||||
follower_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
},
|
||||
cameras={
|
||||
"main": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Robot '{name}' not found.")
|
||||
|
||||
return robot
|
||||
@@ -1,50 +1,110 @@
|
||||
import pickle
|
||||
import time
|
||||
import copy
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode, motor_position_to_angle
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
|
||||
|
||||
########################################################################
|
||||
# Calibration logic
|
||||
########################################################################
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
# TARGET_HORIZONTAL_POSITION = motor_position_to_angle(np.array([0, -1024, 1024, 0, -1024, 0]))
|
||||
# TARGET_90_DEGREE_POSITION = motor_position_to_angle(np.array([1024, 0, 0, 1024, 0, -1024]))
|
||||
# GRIPPER_OPEN = motor_position_to_angle(np.array([-400]))
|
||||
|
||||
# In nominal degree range ]-180, +180[
|
||||
ZERO_POSITION_DEGREE = 0
|
||||
ROTATED_POSITION_DEGREE = 90
|
||||
GRIPPER_OPEN_DEGREE = 35.156
|
||||
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
||||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||
GRIPPER_OPEN = np.array([-400])
|
||||
|
||||
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None:
|
||||
values[i] += homing_offset[i]
|
||||
return values
|
||||
|
||||
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None and drive_mode[i]:
|
||||
values[i] = -values[i]
|
||||
return values
|
||||
|
||||
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
values = apply_homing_offset(values, homing_offset)
|
||||
return values
|
||||
|
||||
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
"""
|
||||
Transform working position into real position for the robot.
|
||||
"""
|
||||
values = apply_homing_offset(values, np.array([
|
||||
-homing_offset if homing_offset is not None else None for homing_offset in homing_offset
|
||||
]))
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
return values
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
assert_drive_mode(drive_mode)
|
||||
# Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
|
||||
# to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
|
||||
signed_drive_mode = -(drive_mode * 2 - 1)
|
||||
position *= signed_drive_mode
|
||||
return position
|
||||
def revert_appropriate_positions(positions: np.array, drive_mode: list[bool]) -> np.array:
|
||||
for i, revert in enumerate(drive_mode):
|
||||
if not revert and positions[i] is not None:
|
||||
positions[i] = -positions[i]
|
||||
return positions
|
||||
|
||||
|
||||
def reset_torque_mode(arm: MotorsBus):
|
||||
def compute_corrections(positions: np.array, drive_mode: list[bool], target_position: np.array) -> np.array:
|
||||
correction = revert_appropriate_positions(positions, drive_mode)
|
||||
|
||||
for i in range(len(positions)):
|
||||
if correction[i] is not None:
|
||||
if drive_mode[i]:
|
||||
correction[i] -= target_position[i]
|
||||
else:
|
||||
correction[i] += target_position[i]
|
||||
|
||||
return correction
|
||||
|
||||
|
||||
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
|
||||
return np.array(
|
||||
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
|
||||
|
||||
|
||||
def compute_homing_offset(arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array) -> np.array:
|
||||
# Get the present positions of the servos
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"),
|
||||
np.array([0, 0, 0, 0, 0, 0]),
|
||||
drive_mode)
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
correction = compute_corrections(nearest_positions, drive_mode, target_position)
|
||||
return correction
|
||||
|
||||
|
||||
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
|
||||
# Get current positions
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"),
|
||||
offset,
|
||||
np.array([False, False, False, False, False, False]))
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
|
||||
# construct 'drive_mode' list comparing nearest_positions and TARGET_90_DEGREE_POSITION
|
||||
drive_mode = []
|
||||
for i in range(len(nearest_positions)):
|
||||
drive_mode.append(nearest_positions[i] != TARGET_90_DEGREE_POSITION[i])
|
||||
return drive_mode
|
||||
|
||||
|
||||
def reset_arm(arm: MotorsBus):
|
||||
# To be configured, all servos must be in "torque disable" mode
|
||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value)
|
||||
|
||||
@@ -53,95 +113,44 @@ def reset_torque_mode(arm: MotorsBus):
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
|
||||
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
|
||||
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
# TODO(rcadene): why?
|
||||
# Use 'position control current based' for gripper
|
||||
arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
|
||||
|
||||
# Make sure the native calibration (homing offset abd drive mode) is disabled, since we use our own calibration layer to be more generic
|
||||
arm.write("Homing_Offset", 0)
|
||||
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
|
||||
|
||||
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
the two robots will move to the same position.To this end, this function computes the homing offset
|
||||
and the drive mode for each motor of a given robot.
|
||||
|
||||
Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
|
||||
to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
|
||||
being 0. During the calibration process, you will need to manually move the robot to this "zero position".
|
||||
def run_arm_calibration(arm: MotorsBus, name: str):
|
||||
reset_arm(arm)
|
||||
|
||||
Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
|
||||
in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
|
||||
to the "rotated position".
|
||||
|
||||
After calibration, the homing offsets and drive modes are stored in a cache.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
run_arm_calibration(arm, "left", "follower")
|
||||
```
|
||||
"""
|
||||
reset_torque_mode(arm)
|
||||
|
||||
print(f"\nRunning calibration of {name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="zero"))
|
||||
# TODO(rcadene): document what position 1 mean
|
||||
print(f"Please move the '{name}' arm to the horizontal position (gripper fully closed)")
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarely choosed our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
|
||||
# corresponds to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
|
||||
zero_position = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
|
||||
horizontal_homing_offset = compute_homing_offset(arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION)
|
||||
|
||||
def _compute_nearest_rounded_position(position, models):
|
||||
# TODO(rcadene): Rework this function since some motors cant physically rotate a quarter turn
|
||||
# (e.g. the gripper of Aloha arms can only rotate ~50 degree)
|
||||
quarter_turn_degree = 90
|
||||
quarter_turn = convert_degrees_to_steps(quarter_turn_degree, models)
|
||||
nearest_pos = np.round(position.astype(float) / quarter_turn) * quarter_turn
|
||||
return nearest_pos.astype(position.dtype)
|
||||
|
||||
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
|
||||
position = arm.read("Present_Position")
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
homing_offset = zero_position - position
|
||||
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rotated"))
|
||||
# TODO(rcadene): document what position 2 mean
|
||||
print(f"Please move the '{name}' arm to the 90 degree position (gripper fully open)")
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
|
||||
# This allows to identify the rotation direction of each motor.
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
rotated_position = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
|
||||
homing_offset = compute_homing_offset(arm, drive_mode, TARGET_90_DEGREE_POSITION)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
position = arm.read("Present_Position")
|
||||
position += homing_offset
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
drive_mode = (position != rotated_position).astype(np.int32)
|
||||
# Invert offset for all drive_mode servos
|
||||
for i in range(len(drive_mode)):
|
||||
if drive_mode[i]:
|
||||
homing_offset[i] = -homing_offset[i]
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
position = arm.read("Present_Position")
|
||||
position = apply_drive_mode(position, drive_mode)
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
homing_offset = rotated_position - position
|
||||
print("Calibration is done!")
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
print("=====================================")
|
||||
print(" HOMING_OFFSET: ", " ".join([str(i) for i in homing_offset]))
|
||||
print(" DRIVE_MODE: ", " ".join([str(i) for i in drive_mode]))
|
||||
print("=====================================")
|
||||
|
||||
return homing_offset, drive_mode
|
||||
|
||||
@@ -150,7 +159,6 @@ def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
|
||||
# Alexander Koch robot arm
|
||||
########################################################################
|
||||
|
||||
|
||||
@dataclass
|
||||
class KochRobotConfig:
|
||||
"""
|
||||
@@ -160,117 +168,53 @@ class KochRobotConfig:
|
||||
```
|
||||
"""
|
||||
|
||||
# Define all components of the robot
|
||||
leader_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
|
||||
follower_arms: dict[str, MotorsBus] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, Camera] = field(default_factory=lambda: {})
|
||||
# Define all the components of the robot
|
||||
leader_arms: dict[str, MotorsBus] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
follower_arms: dict[str, MotorsBus] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
cameras: dict[str, Camera] = field(
|
||||
default_factory=lambda: {}
|
||||
)
|
||||
|
||||
class KochRobot():
|
||||
""" Tau Robotics: https://tau-robotics.com
|
||||
|
||||
class KochRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
"""This class allows to control any Koch robot of various number of motors.
|
||||
|
||||
A few versions are available:
|
||||
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, which was developed
|
||||
by Alexander Koch from [Tau Robotics](https://tau-robotics.com): [Github for sourcing and assembly](
|
||||
- [Koch v1.1])https://github.com/jess-moss/koch-v1-1), which was developed by Jess Moss.
|
||||
|
||||
Example of highest frequency teleoperation without camera:
|
||||
Example of usage:
|
||||
```python
|
||||
# Defines how to communicate with the motors of the leader and follower arms
|
||||
leader_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
}
|
||||
follower_arms = {
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
}
|
||||
robot = KochRobot(leader_arms, follower_arms)
|
||||
|
||||
# Connect motors buses and cameras if any (Required)
|
||||
robot.connect()
|
||||
|
||||
while True:
|
||||
robot.teleop_step()
|
||||
```
|
||||
|
||||
Example of highest frequency data collection without camera:
|
||||
```python
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms, follower_arms)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of highest frequency data collection with cameras:
|
||||
```python
|
||||
# Defines how to communicate with 2 cameras connected to the computer.
|
||||
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
||||
# can be reached respectively using the camera indices 0 and 1. These indices can be
|
||||
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
|
||||
cameras = {
|
||||
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
||||
}
|
||||
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
robot = KochRobot(leader_arms, follower_arms, cameras)
|
||||
robot.connect()
|
||||
while True:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
```
|
||||
|
||||
Example of controlling the robot with a policy (without running multiple policies in parallel to ensure highest frequency):
|
||||
```python
|
||||
# Assumes leader and follower arms + cameras have been instantiated already (see previous example)
|
||||
robot = KochRobot(leader_arms, follower_arms, cameras)
|
||||
robot.connect()
|
||||
while True:
|
||||
# Uses the follower arms and cameras to capture an observation
|
||||
observation = robot.capture_observation()
|
||||
|
||||
# Assumes a policy has been instantiated
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Orders the robot to move
|
||||
robot.send_action(action)
|
||||
```
|
||||
|
||||
Example of disconnecting which is not mandatory since we disconnect when the object is deleted:
|
||||
```python
|
||||
robot.disconnect()
|
||||
robot = KochRobot()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: KochRobotConfig | None = None,
|
||||
calibration_path: Path = ".cache/calibration/koch.pkl",
|
||||
**kwargs,
|
||||
):
|
||||
def __init__(self, config: KochRobotConfig | None = None, calibration_path: Path = ".cache/calibration/koch.pkl", **kwargs):
|
||||
if config is None:
|
||||
config = KochRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
@@ -280,110 +224,62 @@ class KochRobot:
|
||||
self.leader_arms = self.config.leader_arms
|
||||
self.follower_arms = self.config.follower_arms
|
||||
self.cameras = self.config.cameras
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
"KochRobot is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
if not self.leader_arms and not self.follower_arms and not self.cameras:
|
||||
raise ValueError(
|
||||
"KochRobot doesn't have any device to connect. See example of usage in docstring of the class."
|
||||
)
|
||||
|
||||
# Connect the arms
|
||||
for name in self.follower_arms:
|
||||
print(f"Connecting {name} follower arm.")
|
||||
self.follower_arms[name].connect()
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
# Reset the arms and load or run calibration
|
||||
def init_teleop(self):
|
||||
if self.calibration_path.exists():
|
||||
# Reset all arms before setting calibration
|
||||
for name in self.follower_arms:
|
||||
reset_torque_mode(self.follower_arms[name])
|
||||
for name in self.leader_arms:
|
||||
reset_torque_mode(self.leader_arms[name])
|
||||
reset_arm(self.follower_arms[name])
|
||||
|
||||
with open(self.calibration_path, "rb") as f:
|
||||
for name in self.leader_arms:
|
||||
reset_arm(self.leader_arms[name])
|
||||
|
||||
with open(self.calibration_path, 'rb') as f:
|
||||
calibration = pickle.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{self.calibration_path}'. Starting calibration precedure.")
|
||||
# Run calibration process which begins by reseting all arms
|
||||
calibration = self.run_calibration()
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{self.calibration_path}'")
|
||||
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.calibration_path, "wb") as f:
|
||||
with open(self.calibration_path, 'wb') as f:
|
||||
pickle.dump(calibration, f)
|
||||
|
||||
# Set calibration
|
||||
for name in self.follower_arms:
|
||||
self.follower_arms[name].set_calibration(calibration[f"follower_{name}"])
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
|
||||
|
||||
# Set better PID values to close the gap between recored states and actions
|
||||
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
|
||||
for name in self.follower_arms:
|
||||
self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex")
|
||||
self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex")
|
||||
self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex")
|
||||
|
||||
# Enable torque on all motors of the follower arms
|
||||
for name in self.follower_arms:
|
||||
print(f"Activating torque on {name} follower arm.")
|
||||
self.follower_arms[name].write("Torque_Enable", 1)
|
||||
|
||||
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
|
||||
# TODO(rcadene): add comments
|
||||
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN_DEGREE, "gripper")
|
||||
|
||||
# Connect the cameras
|
||||
for name in self.cameras:
|
||||
self.cameras[name].connect()
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def run_calibration(self):
|
||||
calibration = {}
|
||||
|
||||
for name in self.follower_arms:
|
||||
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], name, "follower")
|
||||
|
||||
homing_offset, drive_mode = run_arm_calibration(self.follower_arms[name], f"{name} follower")
|
||||
|
||||
calibration[f"follower_{name}"] = {}
|
||||
for idx, motor_name in enumerate(self.follower_arms[name].motor_names):
|
||||
calibration[f"follower_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
|
||||
|
||||
for name in self.leader_arms:
|
||||
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], name, "leader")
|
||||
|
||||
homing_offset, drive_mode = run_arm_calibration(self.leader_arms[name], f"{name} leader")
|
||||
|
||||
calibration[f"leader_{name}"] = {}
|
||||
for idx, motor_name in enumerate(self.leader_arms[name].motor_names):
|
||||
calibration[f"leader_{name}"][motor_name] = (homing_offset[idx], drive_mode[idx])
|
||||
|
||||
return calibration
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# Prepare to assign the position of the leader to the follower
|
||||
def teleop_step(self, record_data=False) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
# Prepare to assign the positions of the leader to the follower
|
||||
leader_pos = {}
|
||||
for name in self.leader_arms:
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
|
||||
follower_goal_pos = {}
|
||||
for name in self.leader_arms:
|
||||
@@ -391,21 +287,16 @@ class KochRobot:
|
||||
|
||||
# Send action
|
||||
for name in self.follower_arms:
|
||||
before_fwrite_t = time.perf_counter()
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
# TODO(rcadene): Add velocity and other info
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -424,33 +315,22 @@ class KochRobot:
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
images[name] = self.cameras[name].read()
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
action_dict["action"] = torch.from_numpy(action)
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self):
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -462,10 +342,7 @@ class KochRobot:
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
images[name] = self.cameras[name].read()
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict = {}
|
||||
@@ -474,13 +351,7 @@ class KochRobot:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
"""The provided action is expected to be a vector."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
def send_action(self, action):
|
||||
from_idx = 0
|
||||
to_idx = 0
|
||||
follower_goal_pos = {}
|
||||
@@ -492,24 +363,3 @@ class KochRobot:
|
||||
|
||||
for name in self.follower_arms:
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name].astype(np.int32))
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"KochRobot is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
|
||||
for name in self.follower_arms:
|
||||
self.follower_arms[name].disconnect()
|
||||
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].disconnect()
|
||||
|
||||
for name in self.cameras:
|
||||
self.cameras[name].disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class Robot(Protocol):
|
||||
def init_teleop(self): ...
|
||||
def run_calibration(self): ...
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
class RobotDeviceNotConnectedError(Exception):
|
||||
"""Exception raised when the robot device is not connected."""
|
||||
|
||||
def __init__(
|
||||
self, message="This robot device is not connected. Try calling `robot_device.connect()` first."
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class RobotDeviceAlreadyConnectedError(Exception):
|
||||
"""Exception raised when the robot device is already connected."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message="This robot device is already connected. Try not calling `robot_device.connect()` twice.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from contextlib import ContextDecorator
|
||||
|
||||
|
||||
class TimeBenchmark(ContextDecorator):
|
||||
"""
|
||||
Measures execution time using a context manager or decorator.
|
||||
|
||||
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
|
||||
environments.
|
||||
|
||||
Args:
|
||||
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
|
||||
to False.
|
||||
|
||||
Examples:
|
||||
|
||||
Using as a context manager:
|
||||
|
||||
>>> benchmark = TimeBenchmark()
|
||||
>>> with benchmark:
|
||||
... time.sleep(1)
|
||||
>>> print(f"Block took {benchmark.result:.4f} seconds")
|
||||
Block took approximately 1.0000 seconds
|
||||
|
||||
Using with multithreading:
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
benchmark = TimeBenchmark()
|
||||
|
||||
def context_manager_example():
|
||||
with benchmark:
|
||||
time.sleep(0.01)
|
||||
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
|
||||
|
||||
threads = []
|
||||
for _ in range(3):
|
||||
t1 = threading.Thread(target=context_manager_example)
|
||||
threads.append(t1)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
```
|
||||
Expected output:
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, print=False):
|
||||
self.local = threading.local()
|
||||
self.print_time = print
|
||||
|
||||
def __enter__(self):
|
||||
self.local.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.local.end_time = time.perf_counter()
|
||||
self.local.elapsed_time = self.local.end_time - self.local.start_time
|
||||
if self.print_time:
|
||||
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
|
||||
return False
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return getattr(self.local, "elapsed_time", None)
|
||||
|
||||
@property
|
||||
def result_ms(self):
|
||||
return self.result * 1e3
|
||||
@@ -25,3 +25,42 @@ def write_video(video_path, stacked_frames, fps):
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
import serial
|
||||
import os
|
||||
import time
|
||||
|
||||
def reset_usb_port(port):
|
||||
try:
|
||||
# Close the serial port if it's open
|
||||
ser = serial.Serial(port)
|
||||
ser.close()
|
||||
except serial.serialutil.SerialException as e:
|
||||
print(f"Exception while closing the port: {e}")
|
||||
|
||||
# Find the USB device path
|
||||
usb_device_path = None
|
||||
for root, dirs, files in os.walk('/sys/bus/usb/drivers/usb'):
|
||||
for dir_name in dirs:
|
||||
if port in dir_name:
|
||||
usb_device_path = os.path.join(root, dir_name)
|
||||
break
|
||||
|
||||
if usb_device_path:
|
||||
# Unbind and rebind the USB device
|
||||
try:
|
||||
unbind_path = os.path.join(usb_device_path, 'unbind')
|
||||
bind_path = os.path.join(usb_device_path, 'bind')
|
||||
usb_id = os.path.basename(usb_device_path)
|
||||
|
||||
with open(unbind_path, 'w') as f:
|
||||
f.write(usb_id)
|
||||
time.sleep(1) # Wait for a second
|
||||
with open(bind_path, 'w') as f:
|
||||
f.write(usb_id)
|
||||
print(f"USB port {port} has been reset.")
|
||||
except Exception as e:
|
||||
print(f"Exception during USB reset: {e}")
|
||||
else:
|
||||
print(f"Could not find USB device path for port: {port}")
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
@@ -158,7 +158,6 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -173,7 +172,3 @@ def print_cuda_memory_usage():
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
|
||||
|
||||
def capture_timestamp_utc():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
@@ -28,92 +28,21 @@ seed: ???
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datsets are provided.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
video_backend: pyav
|
||||
|
||||
training:
|
||||
offline_steps: ???
|
||||
|
||||
# Number of workers for the offline training dataloader.
|
||||
num_workers: 4
|
||||
|
||||
batch_size: ???
|
||||
|
||||
eval_freq: ???
|
||||
log_freq: 200
|
||||
save_checkpoint: true
|
||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||
save_freq: ???
|
||||
|
||||
# Online training. Note that the online training loop adopts most of the options above apart from the
|
||||
# dataloader options. Unless otherwise specified.
|
||||
# The online training look looks something like:
|
||||
#
|
||||
# for i in range(online_steps):
|
||||
# do_online_rollout_and_update_online_buffer()
|
||||
# for j in range(online_steps_between_rollouts):
|
||||
# batch = next(dataloader_with_offline_and_online_data)
|
||||
# loss = policy(batch)
|
||||
# loss.backward()
|
||||
# optimizer.step()
|
||||
#
|
||||
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||
online_steps: ???
|
||||
# How many episodes to collect at once when we reach the online rollout part of the training loop.
|
||||
online_rollout_n_episodes: 1
|
||||
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
|
||||
# the policy. Ideally you should set this to by an even divisor or online_rollout_n_episodes.
|
||||
online_rollout_batch_size: 1
|
||||
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
|
||||
online_steps_between_rollouts: null
|
||||
# The proportion of online samples (vs offline samples) to include in the online training batches.
|
||||
online_steps_between_rollouts: ???
|
||||
online_sampling_ratio: 0.5
|
||||
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
|
||||
online_env_seed: null
|
||||
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
|
||||
# FIFO.
|
||||
online_buffer_capacity: null
|
||||
# The minimum number of frames to have in the online buffer before commencing online training.
|
||||
# If online_buffer_seed_size > online_rollout_n_episodes, the rollout will be run multiple times until the
|
||||
# seed size condition is satisfied.
|
||||
online_buffer_seed_size: 0
|
||||
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
|
||||
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
|
||||
# + eval + environment rendering simultaneously.
|
||||
do_online_rollout_async: false
|
||||
|
||||
image_transforms:
|
||||
# These transforms are all using standard torchvision.transforms.v2
|
||||
# You can find out how these transformations affect images here:
|
||||
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
# We use a custom RandomSubsetApply container to sample them.
|
||||
# For each transform, the following parameters are available:
|
||||
# weight: This represents the multinomial probability (with no replacement)
|
||||
# used for sampling the transform. If the sum of the weights is not 1,
|
||||
# they will be normalized.
|
||||
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
# (following uniform distribution) when it's applied.
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: false
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number of available transforms].
|
||||
max_num_transforms: 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: false
|
||||
brightness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
contrast:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
saturation:
|
||||
weight: 1
|
||||
min_max: [0.5, 1.5]
|
||||
hue:
|
||||
weight: 1
|
||||
min_max: [-0.05, 0.05]
|
||||
sharpness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
# `online_env_seed` is used for environments for online training data rollouts.
|
||||
online_env_seed: ???
|
||||
eval_freq: ???
|
||||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
@@ -121,6 +50,8 @@ eval:
|
||||
batch_size: 1
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: false
|
||||
# Specify the number of episodes to render during evaluation.
|
||||
max_episodes_rendered: 10
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
|
||||
1
lerobot/configs/env/aloha.yaml
vendored
1
lerobot/configs/env/aloha.yaml
vendored
@@ -9,6 +9,7 @@ env:
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
real_world: false
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
||||
1
lerobot/configs/env/dora_aloha_real.yaml
vendored
1
lerobot/configs/env/dora_aloha_real.yaml
vendored
@@ -9,5 +9,6 @@ env:
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
real_world: true
|
||||
gym:
|
||||
fps: ${fps}
|
||||
|
||||
10
lerobot/configs/env/koch_real.yaml
vendored
10
lerobot/configs/env/koch_real.yaml
vendored
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
1
lerobot/configs/env/pusht.yaml
vendored
1
lerobot/configs/env/pusht.yaml
vendored
@@ -10,6 +10,7 @@ env:
|
||||
action_dim: 2
|
||||
fps: ${fps}
|
||||
episode_length: 300
|
||||
real_world: false
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
||||
3
lerobot/configs/env/xarm.yaml
vendored
3
lerobot/configs/env/xarm.yaml
vendored
@@ -9,7 +9,8 @@ env:
|
||||
state_dim: 4
|
||||
action_dim: 4
|
||||
fps: ${fps}
|
||||
episode_length: 200
|
||||
episode_length: 25
|
||||
real_world: false
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
||||
@@ -10,10 +10,11 @@ override_dataset_stats:
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: 20000
|
||||
save_freq: 20000
|
||||
eval_freq: 10000
|
||||
save_freq: 100000
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
@@ -75,7 +76,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -36,10 +36,11 @@ override_dataset_stats:
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
@@ -107,7 +108,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -34,10 +34,11 @@ override_dataset_stats:
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 100000
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 20000
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
@@ -103,7 +104,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -24,8 +24,9 @@ override_dataset_stats:
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
@@ -98,7 +99,7 @@ policy:
|
||||
clip_sample_range: 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: null # if not provided, defaults to `num_train_timesteps`
|
||||
num_inference_steps: 100
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: false
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Defaults for training for the pusht_keypoints dataset.
|
||||
|
||||
# They keypoints are on the vertices of the rectangles that make up the PushT as documented in the PushT
|
||||
# environment:
|
||||
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
|
||||
# For completeness, the diagram is copied here:
|
||||
# 0───────────1
|
||||
# │ │
|
||||
# 3───4───5───2
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# │ │
|
||||
# 7───6
|
||||
|
||||
|
||||
# Note: The original work trains keypoints-only with conditioning via inpainting. Here, we encode the
|
||||
# observation along with the agent position and use the encoding as global conditioning for the denoising
|
||||
# U-Net.
|
||||
|
||||
# Note: We do not track EMA model weights as we discovered it does not improve the results. See
|
||||
# https://github.com/huggingface/lerobot/pull/134 for more details.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: diffusion
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 2
|
||||
horizon: 16
|
||||
n_action_steps: 8
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# Unet.
|
||||
down_dims: [256, 512, 1024]
|
||||
kernel_size: 5
|
||||
n_groups: 8
|
||||
diffusion_step_embed_dim: 128
|
||||
use_film_scale_modulation: True
|
||||
# Noise scheduler.
|
||||
noise_scheduler_type: DDIM
|
||||
num_train_timesteps: 100
|
||||
beta_schedule: squaredcos_cap_v2
|
||||
beta_start: 0.0001
|
||||
beta_end: 0.02
|
||||
prediction_type: epsilon # epsilon / sample
|
||||
clip_sample: True
|
||||
clip_sample_range: 1.0
|
||||
|
||||
# Inference
|
||||
num_inference_steps: 10 # if not provided, defaults to `num_train_timesteps`
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: false
|
||||
@@ -4,30 +4,18 @@ seed: 1
|
||||
dataset_repo_id: lerobot/xarm_lift_medium
|
||||
|
||||
training:
|
||||
offline_steps: 50000
|
||||
|
||||
num_workers: 4
|
||||
offline_steps: 25000
|
||||
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||
online_steps: 0 # 25000 not implemented yet
|
||||
eval_freq: 5000
|
||||
online_steps_between_rollouts: 1
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 5000
|
||||
log_freq: 100
|
||||
|
||||
online_steps: 50000
|
||||
online_rollout_n_episodes: 1
|
||||
online_rollout_batch_size: 1
|
||||
# Note: in FOWM `online_steps_between_rollouts` is actually dynamically set to match exactly the length of
|
||||
# the last sampled episode.
|
||||
online_steps_between_rollouts: 50
|
||||
online_sampling_ratio: 0.5
|
||||
online_env_seed: 10000
|
||||
# FOWM Push uses 10000 for `online_buffer_capacity`. Given that their maximum episode length for this task
|
||||
# is 25, 10000 is approx 400 of their episodes worth. Since our episodes are about 8 times longer, we'll use
|
||||
# 80000.
|
||||
online_buffer_capacity: 80000
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
@@ -42,7 +30,6 @@ policy:
|
||||
# Input / output structure.
|
||||
n_action_repeats: 2
|
||||
horizon: 5
|
||||
n_action_steps: 1
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
@@ -67,7 +54,7 @@ policy:
|
||||
discount: 0.9
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
use_mpc: false
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
# policy=tdmpc_pusht_keypoints \
|
||||
# eval.batch_size=50 \
|
||||
# eval.n_episodes=50 \
|
||||
# eval.use_async_envs=true \
|
||||
# device=cuda \
|
||||
# use_amp=true
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 10000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: tdmpc
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
state_encoder_hidden_dim: 256
|
||||
latent_dim: 50
|
||||
q_ensemble_size: 5
|
||||
mlp_dim: 512
|
||||
# Reinforcement learning.
|
||||
discount: 0.98
|
||||
|
||||
# Inference.
|
||||
use_mpc: true
|
||||
cem_iterations: 6
|
||||
max_std: 2.0
|
||||
min_std: 0.05
|
||||
n_gaussian_samples: 512
|
||||
n_pi_samples: 51
|
||||
uncertainty_regularizer_coeff: 1.0
|
||||
n_elites: 50
|
||||
elite_weighting_temperature: 0.5
|
||||
gaussian_mean_momentum: 0.1
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: 0.5
|
||||
expectile_weight: 0.9
|
||||
value_coeff: 0.1
|
||||
consistency_coeff: 20.0
|
||||
advantage_scaling: 3.0
|
||||
pi_coeff: 0.5
|
||||
temporal_decay_coeff: 0.5
|
||||
# Target model.
|
||||
target_model_momentum: 0.995
|
||||
@@ -1,103 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Defaults for training for the PushT dataset.
|
||||
|
||||
seed: 100000
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
override_dataset_stats:
|
||||
# TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
|
||||
observation.image:
|
||||
mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
|
||||
# TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
|
||||
# from the original codebase, but we should remove these and train our own pretrained model
|
||||
observation.state:
|
||||
min: [13.456424, 32.938293]
|
||||
max: [496.14618, 510.9579]
|
||||
action:
|
||||
min: [12.0, 25.0]
|
||||
max: [511.0, 511.0]
|
||||
|
||||
training:
|
||||
offline_steps: 250000
|
||||
online_steps: 0
|
||||
eval_freq: 25000
|
||||
save_freq: 25000
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 64
|
||||
grad_clip_norm: 10
|
||||
lr: 1.0e-4
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
adam_betas: [0.95, 0.999]
|
||||
adam_eps: 1.0e-8
|
||||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
# VQ-BeT specific
|
||||
vqvae_lr: 1.0e-3
|
||||
n_vqvae_training_steps: 20000
|
||||
bet_weight_decay: 2e-4
|
||||
bet_learning_rate: 5.5e-5
|
||||
bet_betas: [0.9, 0.999]
|
||||
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
policy:
|
||||
name: vqbet
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 5
|
||||
n_action_pred_token: 7
|
||||
action_chunk_size: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.image: [3, 96, 96]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.image: mean_std
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
crop_shape: [84, 84]
|
||||
crop_is_random: True
|
||||
pretrained_backbone_weights: null
|
||||
use_group_norm: True
|
||||
spatial_softmax_num_keypoints: 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: ${training.n_vqvae_training_steps}
|
||||
vqvae_n_embed: 16
|
||||
vqvae_embedding_dim: 256
|
||||
vqvae_enc_hidden_dim: 128
|
||||
# VQ-BeT
|
||||
gpt_block_size: 500
|
||||
gpt_input_dim: 512
|
||||
gpt_output_dim: 512
|
||||
gpt_n_layer: 8
|
||||
gpt_n_head: 8
|
||||
gpt_hidden_dim: 512
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024
|
||||
offset_loss_weight: 10000.
|
||||
primary_code_loss_weight: 5.0
|
||||
secondary_code_loss_weight: 0.5
|
||||
bet_softmax_temperature: 0.1
|
||||
sequentially_select: False
|
||||
@@ -1,39 +0,0 @@
|
||||
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
|
||||
calibration_path: .cache/calibration/koch.pkl
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -1,22 +1,9 @@
|
||||
"""
|
||||
Utilities to control a robot.
|
||||
|
||||
Useful to record a dataset, replay a recorded episode, run the policy on your robot
|
||||
and record an evaluation dataset, and to recalibrate your robot if needed.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Recalibrate your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate
|
||||
```
|
||||
Example of usage:
|
||||
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate
|
||||
|
||||
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
|
||||
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
||||
@@ -27,7 +14,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
@@ -45,123 +32,65 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
|
||||
- Replay this test episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
python lerobot/scripts/control_robot.py replay_episode \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
- Record a full dataset in order to train a policy:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
--run-compute-stats 1
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment.
|
||||
- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode.
|
||||
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
To avoid resuming by deleting the dataset, use `--force-override 1`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
- Train on this dataset (TODO(rcadene)):
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
hydra.run.dir=outputs/train/act_koch_real
|
||||
python lerobot/scripts/train.py
|
||||
```
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
python lerobot/scripts/control_robot.py run_policy \
|
||||
-p TODO(rcadene)
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
from termcolor import colored
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from omegaconf import DictConfig
|
||||
import torch
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, load_hf_dataset
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
from lerobot.scripts.push_dataset_to_hub import save_meta_data
|
||||
from lerobot.scripts.robot_controls.record_dataset import record_dataset
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def say(text, blocking=False):
|
||||
# Check if mac, linux, or windows.
|
||||
if platform.system() == "Darwin":
|
||||
cmd = f'say "{text}"'
|
||||
elif platform.system() == "Linux":
|
||||
cmd = f'spd-say "{text}"'
|
||||
elif platform.system() == "Windows":
|
||||
cmd = (
|
||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
||||
)
|
||||
|
||||
if not blocking and platform.system() in ["Darwin", "Linux"]:
|
||||
# TODO(rcadene): Make it work for Windows
|
||||
# Use the ampersand to run command in the background
|
||||
cmd += " &"
|
||||
|
||||
os.system(cmd)
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
@@ -172,323 +101,97 @@ def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
def busy_wait(seconds):
|
||||
# Significantly more accurate than `time.sleep`, and mendatory for our use case,
|
||||
# but it consumes CPU cycles.
|
||||
# TODO(rcadene): find an alternative: from python 11, time.sleep is precise
|
||||
# TODO(rcadene): find an alternative
|
||||
end_time = time.perf_counter() + seconds
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == "None":
|
||||
if value == 'None':
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items += [f"ep:{episode_index}"]
|
||||
if frame_index is not None:
|
||||
log_items += [f"frame:{frame_index}"]
|
||||
|
||||
def log_dt(shortname, dt_val_s):
|
||||
nonlocal log_items
|
||||
log_items += [f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"]
|
||||
|
||||
# total step time displayed in milliseconds and its frequency
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
for name in robot.leader_arms:
|
||||
key = f"read_leader_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRlead", robot.logs[key])
|
||||
|
||||
for name in robot.follower_arms:
|
||||
key = f"write_follower_{name}_goal_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
|
||||
key = f"read_follower_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
|
||||
for name in robot.cameras:
|
||||
key = f"read_camera_{name}_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt(f"dtR{name}", robot.logs[key])
|
||||
|
||||
info_str = " ".join(log_items)
|
||||
if fps is not None:
|
||||
actual_fps = 1 / dt_s
|
||||
if actual_fps < fps - 1:
|
||||
info_str = colored(info_str, "yellow")
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
print(
|
||||
"Error trying to import pynput. Switching to headless mode. "
|
||||
"As a result, the video stream from the cameras won't be shown, "
|
||||
"and you won't be able to change the control flow with keyboards. "
|
||||
"For more info, see traceback below.\n"
|
||||
)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
return True
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None):
|
||||
robot.init_teleop()
|
||||
|
||||
def calibrate(robot: Robot):
|
||||
if robot.calibration_path.exists():
|
||||
print(f"Removing '{robot.calibration_path}'")
|
||||
robot.calibration_path.unlink()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
# Calling `connect` automatically runs calibration
|
||||
# when the calibration file is missing
|
||||
robot.connect()
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_teleop_t = time.perf_counter()
|
||||
while True:
|
||||
start_loop_t = time.perf_counter()
|
||||
now = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
dt_s = (time.perf_counter() - now)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
break
|
||||
dt_s = (time.perf_counter() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
def record(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module | None = None,
|
||||
hydra_cfg: DictConfig | None = None,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writers=8,
|
||||
force_override=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
# TODO(rcadene): Clean this function via decomposition in higher level functions
|
||||
|
||||
_, dataset_name = repo_id.split("/")
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
||||
|
||||
def record_dataset(robot: Robot, fps: int | None = None, root="data", repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, num_episodes=50, video=True, run_compute_stats=True):
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
robot.init_teleop()
|
||||
|
||||
local_dir = Path(root) / repo_id
|
||||
if local_dir.exists() and force_override:
|
||||
if local_dir.exists():
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
episodes_dir = local_dir / "episodes"
|
||||
episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Logic to resume data recording
|
||||
rec_info_path = episodes_dir / "data_recording_info.json"
|
||||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
episode_index = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
episode_index = 0
|
||||
|
||||
if is_headless():
|
||||
logging.info(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
exit_early = False
|
||||
rerecord_episode = False
|
||||
stop_recording = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
nonlocal exit_early, rerecord_episode, stop_recording
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
exit_early = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
rerecord_episode = True
|
||||
exit_early = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
stop_recording = True
|
||||
exit_early = True
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
# Load policy if any
|
||||
if policy is not None:
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
# override fps using policy fps
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_warmup_t = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
say("Warming up")
|
||||
is_warmup_print = True
|
||||
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
is_record_print = False
|
||||
ep_dicts = []
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
# Using only 4 worker threads to avoid blocking the main thread.
|
||||
futures = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
logging.info(f"Recording episode {episode_index}")
|
||||
say(f"Recording episode {episode_index}")
|
||||
# Using `with` ensures the program exists smoothly if an execption is raised.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
for episode_index in range(num_episodes):
|
||||
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
while True:
|
||||
if not is_warmup_print:
|
||||
print("Warming up by skipping frames")
|
||||
os.system('say "Warmup"')
|
||||
is_warmup_print = True
|
||||
now = time.perf_counter()
|
||||
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
if timestamp < warmup_time_s:
|
||||
dt_s = (time.perf_counter() - now)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f} (Warmup)")
|
||||
continue
|
||||
|
||||
if not is_record_print:
|
||||
print("Recording")
|
||||
os.system(f'say "Recording episode {episode_index}"')
|
||||
is_record_print = True
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
for key in image_keys:
|
||||
futures += [
|
||||
executor.submit(
|
||||
save_image, observation[key], key, frame_index, episode_index, videos_dir
|
||||
)
|
||||
]
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)
|
||||
|
||||
for key in not_image_keys:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
if policy is not None:
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
# Order the robot to move
|
||||
robot.send_action(action)
|
||||
action = {"action": action}
|
||||
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
@@ -496,42 +199,37 @@ def record(
|
||||
|
||||
frame_index += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
dt_s = (time.perf_counter() - now)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
dt_s = (time.perf_counter() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
if timestamp > episode_time_s - warmup_time_s:
|
||||
break
|
||||
|
||||
if not stop_recording:
|
||||
# Start resetting env while the executor are finishing
|
||||
logging.info("Reset the environment")
|
||||
say("Reset the environment")
|
||||
print("Encoding to `LeRobotDataset` format")
|
||||
os.system('say "Encoding"')
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
||||
# During env reset we save the data and encode the videos
|
||||
num_frames = frame_index
|
||||
|
||||
for key in image_keys:
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
video_path.unlink()
|
||||
# Store the reference to the video frame, even tho the videos are not yet encoded
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps)
|
||||
|
||||
# clean temporary images directory
|
||||
# shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
# store the reference to the video frame
|
||||
ep_dict[key] = []
|
||||
for i in range(num_frames):
|
||||
ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
|
||||
|
||||
for key in not_image_keys:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
|
||||
for key in action:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
@@ -543,76 +241,8 @@ def record(
|
||||
done[-1] = True
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
print("Saving episode dictionary...")
|
||||
torch.save(ep_dict, ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
rec_info = {
|
||||
"last_episode_index": episode_index,
|
||||
}
|
||||
with open(rec_info_path, "w") as f:
|
||||
json.dump(rec_info, f)
|
||||
|
||||
is_last_episode = stop_recording or (episode_index == (num_episodes - 1))
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s and not is_last_episode:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
break
|
||||
|
||||
# Skip updating episode index which forces re-recording episode
|
||||
if rerecord_episode:
|
||||
rerecord_episode = False
|
||||
continue
|
||||
|
||||
episode_index += 1
|
||||
|
||||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
say("Done recording", blocking=True)
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
for _ in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||
):
|
||||
pass
|
||||
break
|
||||
|
||||
robot.disconnect()
|
||||
if not is_headless():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
num_episodes = episode_index
|
||||
|
||||
logging.info("Encoding videos")
|
||||
say("Encoding videos")
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
logging.info("Concatenating episodes")
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
total_frames = data_dict["frame_index"].shape[0]
|
||||
@@ -621,12 +251,16 @@ def record(
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
for key in image_keys:
|
||||
time.sleep(10)
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
# shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
@@ -636,35 +270,19 @@ def record(
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
say("Computing dataset statistics")
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
lerobot_dataset.stats = stats
|
||||
else:
|
||||
stats = {}
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
logging.info("Exiting")
|
||||
say("Exiting")
|
||||
return lerobot_dataset
|
||||
# TODO(rcadene): push to hub
|
||||
|
||||
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
@@ -674,22 +292,51 @@ def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo
|
||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
robot.init_teleop()
|
||||
|
||||
print("Replaying episode")
|
||||
os.system('say "Replaying episode"')
|
||||
|
||||
logging.info("Replaying episode")
|
||||
say("Replaying episode", blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
start_episode_t = time.perf_counter()
|
||||
now = time.perf_counter()
|
||||
|
||||
action = items[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
dt_s = (time.perf_counter() - now)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
dt_s = (time.perf_counter() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig):
|
||||
policy.eval()
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
with torch.inference_mode(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = (time.perf_counter() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -698,165 +345,56 @@ if __name__ == "__main__":
|
||||
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
base_parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
|
||||
base_parser.add_argument("--robot", type=str, default="koch", help="Name of the robot provided to the `make_robot(name)` factory function.")
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_teleop.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default="data",
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--warmup-time-s",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--episode-time-s",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of seconds for data recording for each episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--reset-time-s",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of seconds for resetting the environment after each episode.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
parser_record.add_argument(
|
||||
"--run-compute-stats",
|
||||
type=int,
|
||||
default=1,
|
||||
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of threads writing the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||
parser_record.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_record.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_record.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_record.add_argument('--warmup-time-s', type=int, default=2, help='')
|
||||
parser_record.add_argument('--episode-time-s', type=int, default=10, help='')
|
||||
parser_record.add_argument('--num-episodes', type=int, default=50, help='')
|
||||
parser_record.add_argument('--run-compute-stats', type=int, default=1, help='')
|
||||
|
||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||
parser_replay.add_argument('--fps', type=none_or_int, default=None, help='Frames per second (set to None to disable)')
|
||||
parser_replay.add_argument('--root', type=Path, default="data", help='')
|
||||
parser_replay.add_argument('--repo-id', type=str, default="lerobot/test", help='')
|
||||
parser_replay.add_argument('--episode', type=int, default=0, help='')
|
||||
|
||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||
parser_policy.add_argument('-p', '--pretrained-policy-name-or-path', type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--policy-overrides",
|
||||
type=str,
|
||||
parser_policy.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default="data",
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
|
||||
control_mode = args.mode
|
||||
robot_path = args.robot_path
|
||||
robot_overrides = args.robot_overrides
|
||||
robot_name = args.robot
|
||||
kwargs = vars(args)
|
||||
del kwargs["mode"]
|
||||
del kwargs["robot_path"]
|
||||
del kwargs["robot_overrides"]
|
||||
del kwargs["robot"]
|
||||
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
if control_mode == "calibrate":
|
||||
calibrate(robot, **kwargs)
|
||||
|
||||
elif control_mode == "teleoperate":
|
||||
robot = make_robot(robot_name)
|
||||
if control_mode == "teleoperate":
|
||||
teleoperate(robot, **kwargs)
|
||||
elif control_mode == "record_dataset":
|
||||
record_dataset(robot, **kwargs)
|
||||
elif control_mode == "replay_episode":
|
||||
replay_episode(robot, **kwargs)
|
||||
|
||||
elif control_mode == "record":
|
||||
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
|
||||
policy_overrides = args.policy_overrides
|
||||
del kwargs["pretrained_policy_name_or_path"]
|
||||
del kwargs["policy_overrides"]
|
||||
|
||||
policy_cfg = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
record(robot, policy, policy_cfg, **kwargs)
|
||||
else:
|
||||
record(robot, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(robot, **kwargs)
|
||||
|
||||
if robot.is_connected:
|
||||
# Disconnect manually to avoid a "Core dump" during process
|
||||
# termination due to camera threads not properly exiting.
|
||||
robot.disconnect()
|
||||
elif control_mode == "run_policy":
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
run_policy(robot, policy, hydra_cfg)
|
||||
|
||||
@@ -13,71 +13,39 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Use this script to get a quick summary of your system config.
|
||||
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
|
||||
"""
|
||||
|
||||
import platform
|
||||
|
||||
HAS_HF_HUB = True
|
||||
HAS_HF_DATASETS = True
|
||||
HAS_NP = True
|
||||
HAS_TORCH = True
|
||||
HAS_LEROBOT = True
|
||||
import huggingface_hub
|
||||
|
||||
try:
|
||||
import huggingface_hub
|
||||
except ImportError:
|
||||
HAS_HF_HUB = False
|
||||
# import dataset
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import datasets
|
||||
except ImportError:
|
||||
HAS_HF_DATASETS = False
|
||||
from lerobot import __version__ as version
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
HAS_NP = False
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
|
||||
try:
|
||||
import lerobot
|
||||
except ImportError:
|
||||
HAS_LEROBOT = False
|
||||
|
||||
|
||||
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
|
||||
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
|
||||
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
|
||||
np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
pt_version = torch.__version__
|
||||
pt_cuda_available = torch.cuda.is_available()
|
||||
pt_cuda_available = torch.cuda.is_available()
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
def display_sys_info() -> dict:
|
||||
"""Run this to get basic system info to help for tracking issues & bugs."""
|
||||
info = {
|
||||
"`lerobot` version": lerobot_version,
|
||||
"`lerobot` version": version,
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"Huggingface_hub version": hf_hub_version,
|
||||
"Dataset version": hf_datasets_version,
|
||||
"Numpy version": np_version,
|
||||
"PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
|
||||
"Huggingface_hub version": huggingface_hub.__version__,
|
||||
# TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
|
||||
# "Dataset version": dataset.__version__,
|
||||
"Numpy version": np.__version__,
|
||||
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
||||
"Cuda version": cuda_version,
|
||||
"Using GPU in script?": "<fill in>",
|
||||
# "Using distributed or parallel set-up in script?": "<fill in>",
|
||||
"Using distributed or parallel set-up in script?": "<fill in>",
|
||||
}
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
|
||||
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
||||
print(format_dict(info))
|
||||
return info
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
@@ -56,13 +57,16 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import Tensor, nn
|
||||
from PIL import Image as PILImage
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
@@ -96,13 +100,13 @@ def rollout(
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
"don": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
policy: The policy.
|
||||
seeds: The environments are seeded once at the start of the rollout. If provided, this argument
|
||||
specifies the seeds for each of the environments.
|
||||
return_observations: Whether to include all observations in the returned rollout data. Observations
|
||||
@@ -113,7 +117,6 @@ def rollout(
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
device = get_device_from_parameters(policy)
|
||||
|
||||
# Reset the policy and environments.
|
||||
@@ -162,7 +165,10 @@ def rollout(
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available of none of the envs finished.
|
||||
if "final_info" in info:
|
||||
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
|
||||
successes = [
|
||||
i["is_success"] if (i is not None and "is_success" in i) else False
|
||||
for i in info["final_info"]
|
||||
]
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
@@ -207,7 +213,7 @@ def eval_policy(
|
||||
policy: torch.nn.Module,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
videos_dir: Path | None = None,
|
||||
video_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
enable_progbar: bool = False,
|
||||
@@ -219,7 +225,7 @@ def eval_policy(
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
max_episodes_rendered: Maximum number of episodes to render into videos.
|
||||
videos_dir: Where to save rendered videos.
|
||||
video_dir: Where to save rendered videos.
|
||||
return_episode_data: Whether to return episode data for online training. Incorporates the data into
|
||||
the "episodes" key of the returned dictionary.
|
||||
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
|
||||
@@ -229,10 +235,6 @@ def eval_policy(
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
if max_episodes_rendered > 0 and not videos_dir:
|
||||
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
|
||||
|
||||
assert isinstance(policy, Policy)
|
||||
start = time.time()
|
||||
policy.eval()
|
||||
|
||||
@@ -273,16 +275,11 @@ def eval_policy(
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
else:
|
||||
seeds = range(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
)
|
||||
seeds = range(start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs))
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
policy,
|
||||
seeds=list(seeds) if seeds else None,
|
||||
seeds=seeds,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
enable_progbar=enable_inner_progbar,
|
||||
@@ -292,8 +289,7 @@ def eval_policy(
|
||||
# this won't be included).
|
||||
n_steps = rollout_data["done"].shape[1]
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
|
||||
|
||||
done_indices = torch.argmax(rollout_data["done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
|
||||
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
|
||||
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
|
||||
@@ -304,28 +300,48 @@ def eval_policy(
|
||||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
if seeds:
|
||||
all_seeds.extend(seeds)
|
||||
else:
|
||||
all_seeds.append(None)
|
||||
all_seeds.extend(seeds)
|
||||
|
||||
# FIXME: episode_data is either None or it doesn't exist
|
||||
if return_episode_data:
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=(
|
||||
0 if episode_data is None else (episode_data["episode_data_index"]["to"][-1].item())
|
||||
),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are correctly compiling the data.
|
||||
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
|
||||
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
|
||||
# Some sanity checks to make sure we are not correctly compiling the data.
|
||||
assert (
|
||||
episode_data["hf_dataset"]["episode_index"][-1] + 1
|
||||
== this_episode_data["hf_dataset"]["episode_index"][0]
|
||||
)
|
||||
assert (
|
||||
episode_data["hf_dataset"]["index"][-1] + 1 == this_episode_data["hf_dataset"]["index"][0]
|
||||
)
|
||||
assert torch.equal(
|
||||
episode_data["episode_data_index"]["to"][-1],
|
||||
this_episode_data["episode_data_index"]["from"][0],
|
||||
)
|
||||
# Concatenate the episode data.
|
||||
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
|
||||
episode_data = {
|
||||
"hf_dataset": concatenate_datasets(
|
||||
[episode_data["hf_dataset"], this_episode_data["hf_dataset"]]
|
||||
),
|
||||
"episode_data_index": {
|
||||
k: torch.cat(
|
||||
[
|
||||
episode_data["episode_data_index"][k],
|
||||
this_episode_data["episode_data_index"][k],
|
||||
]
|
||||
)
|
||||
for k in ["from", "to"]
|
||||
},
|
||||
}
|
||||
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0 and len(ep_frames) > 0:
|
||||
@@ -335,9 +351,8 @@ def eval_policy(
|
||||
):
|
||||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
break
|
||||
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
@@ -407,55 +422,108 @@ def _compile_episode_data(
|
||||
Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`).
|
||||
"""
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
total_frames = 0
|
||||
data_index_from = start_data_index
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
# + 2 to include the first done frame and the last observation frame.
|
||||
num_frames = done_indices[ep_ix].item() + 2
|
||||
num_frames = done_indices[ep_ix].item() + 1 # + 1 to include the first done frame
|
||||
total_frames += num_frames
|
||||
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
# TODO(rcadene): We need to add a missing last frame which is the observation
|
||||
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
||||
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
||||
"action": rollout_data["action"][ep_ix, :num_frames],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, :num_frames],
|
||||
"next.reward": rollout_data["reward"][ep_ix, :num_frames].type(torch.float32),
|
||||
}
|
||||
|
||||
# For the last observation frame, all other keys will just be copy padded.
|
||||
for k in ep_dict:
|
||||
ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]])
|
||||
|
||||
for key in rollout_data["observation"]:
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames]
|
||||
|
||||
ep_dict[key] = rollout_data["observation"][key][ep_ix][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
episode_data_index["from"].append(data_index_from)
|
||||
episode_data_index["to"].append(data_index_from + num_frames)
|
||||
|
||||
data_index_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
for key in ep_dicts[0]:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if "image" not in key:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
episode_data_index["from"] = torch.tensor(episode_data_index["from"])
|
||||
episode_data_index["to"] = torch.tensor(episode_data_index["to"])
|
||||
|
||||
return data_dict
|
||||
# TODO(rcadene): clean this
|
||||
features = {}
|
||||
for key in rollout_data["observation"]:
|
||||
if "image" in key:
|
||||
features[key] = Image()
|
||||
else:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
features.update(
|
||||
{
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_index": Value(dtype="int64", id=None),
|
||||
"frame_index": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
}
|
||||
)
|
||||
features = Features(features)
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
|
||||
|
||||
def main(
|
||||
pretrained_policy_path: Path | None = None,
|
||||
def eval(
|
||||
pretrained_policy_path: str | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
out_dir: str | None = None,
|
||||
config_overrides: list[str] | None = None,
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
if pretrained_policy_path is not None:
|
||||
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
|
||||
if hydra_cfg_path is None:
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
out_dir = (
|
||||
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
if out_dir is None:
|
||||
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
raise NotImplementedError()
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
@@ -471,12 +539,10 @@ def main(
|
||||
|
||||
logging.info("Making policy.")
|
||||
if hydra_cfg_path is None:
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||
@@ -484,8 +550,8 @@ def main(
|
||||
env,
|
||||
policy,
|
||||
hydra_cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(out_dir) / "videos",
|
||||
max_episodes_rendered=hydra_cfg.eval.max_episodes_rendered,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
start_seed=hydra_cfg.seed,
|
||||
enable_progbar=True,
|
||||
enable_inner_progbar=True,
|
||||
@@ -503,7 +569,9 @@ def main(
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
pretrained_policy_path = Path(
|
||||
snapshot_download(pretrained_policy_name_or_path, revision=revision)
|
||||
)
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
@@ -548,13 +616,6 @@ if __name__ == "__main__":
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
@@ -563,14 +624,8 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
|
||||
else:
|
||||
pretrained_policy_path = get_pretrained_policy_path(
|
||||
args.pretrained_policy_name_or_path, revision=args.revision
|
||||
)
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path, revision=args.revision)
|
||||
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
out_dir=args.out_dir,
|
||||
config_overrides=args.overrides,
|
||||
)
|
||||
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
||||
|
||||
@@ -18,36 +18,54 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
|
||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||
|
||||
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||
Example:
|
||||
```
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/pusht_raw \
|
||||
--data-dir data \
|
||||
--dataset-id pusht \
|
||||
--raw-format pusht_zarr \
|
||||
--repo-id lerobot/pusht
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/xarm_lift_medium_raw \
|
||||
--data-dir data \
|
||||
--dataset-id xarm_lift_medium \
|
||||
--raw-format xarm_pkl \
|
||||
--repo-id lerobot/xarm_lift_medium
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_sim_insertion_scripted \
|
||||
--raw-format aloha_hdf5 \
|
||||
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||
--data-dir data \
|
||||
--dataset-id umi_cup_in_the_wild \
|
||||
--raw-format umi_zarr \
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
@@ -55,23 +73,21 @@ from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||
if raw_format == "pusht_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "umi_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "cam_png":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
@@ -80,9 +96,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
|
||||
def save_meta_data(
|
||||
info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
|
||||
):
|
||||
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# save info
|
||||
@@ -100,7 +114,7 @@ def save_meta_data(
|
||||
save_file(episode_data_index, ep_data_idx_path)
|
||||
|
||||
|
||||
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
|
||||
def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
|
||||
"""Expect all meta data files to be all stored in a single "meta_data" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "meta_data" directory at the root.
|
||||
"""
|
||||
@@ -114,15 +128,7 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
|
||||
)
|
||||
|
||||
|
||||
def push_dataset_card_to_hub(
|
||||
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, text=text)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||
def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||
"""
|
||||
@@ -138,71 +144,55 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
raw_dir: Path,
|
||||
raw_format: str,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = True,
|
||||
local_dir: Path | None = None,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
episodes: list[int] | None = None,
|
||||
force_override: bool = False,
|
||||
resume: bool = False,
|
||||
cache_dir: Path = Path("/tmp"),
|
||||
tests_data_dir: Path | None = None,
|
||||
encoding: dict | None = None,
|
||||
data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
revision: str,
|
||||
dry_run: bool,
|
||||
save_to_disk: bool,
|
||||
tests_data_dir: Path,
|
||||
save_tests_to_disk: bool,
|
||||
fps: int | None,
|
||||
video: bool,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
debug: bool,
|
||||
):
|
||||
check_repo_id(repo_id)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
# Robustify when `raw_dir` is str instead of Path
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
raise NotADirectoryError(
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub: "
|
||||
f"`python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw`"
|
||||
)
|
||||
|
||||
if local_dir:
|
||||
# Robustify when `local_dir` is str instead of Path
|
||||
local_dir = Path(local_dir)
|
||||
|
||||
# Send warning if local_dir isn't well formated
|
||||
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
# Check we don't override an existing `local_dir` by mistake
|
||||
if local_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
elif not resume:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
videos_dir = local_dir / "videos"
|
||||
else:
|
||||
# Temporary directory used to store images, videos, meta_data
|
||||
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||
videos_dir = Path(cache_dir) / "videos"
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps, video, episodes, encoding
|
||||
)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
version=revision,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
@@ -210,82 +200,102 @@ def push_dataset_to_hub(
|
||||
)
|
||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||
|
||||
if local_dir:
|
||||
if save_to_disk:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
|
||||
if push_to_hub or local_dir:
|
||||
if not dry_run or save_to_disk:
|
||||
# mandatory for upload
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
if not dry_run:
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
||||
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
||||
|
||||
if tests_data_dir:
|
||||
if save_tests_to_disk:
|
||||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||
|
||||
tests_meta_data = tests_data_dir / repo_id / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data)
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
||||
|
||||
# copy videos of first episode to tests directory
|
||||
episode_index = 0
|
||||
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
for key in lerobot_dataset.video_frame_keys:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if local_dir is None:
|
||||
# clear cache
|
||||
shutil.rmtree(meta_data_dir)
|
||||
shutil.rmtree(videos_dir)
|
||||
|
||||
return lerobot_dataset
|
||||
if not save_to_disk and out_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--raw-dir",
|
||||
"--data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
|
||||
)
|
||||
# TODO(rcadene): add automatic detection of the format
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
"--community-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
default="lerobot",
|
||||
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
|
||||
"--revision",
|
||||
type=str,
|
||||
default=CODEBASE_VERSION,
|
||||
help="Codebase version used to generate the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
"--dry-run",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Upload to hub.",
|
||||
help="Save the dataset in the directory specified by `--data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
default="tests/data",
|
||||
help="Directory containing tests artifacts datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-tests-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
@@ -311,30 +321,10 @@ def main():
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
"--debug",
|
||||
type=int,
|
||||
default=0,
|
||||
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="When set to 1, resumes a previous run.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help=(
|
||||
"When provided, save tests artifacts into the given directory "
|
||||
"(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
|
||||
),
|
||||
help="Debug mode process the first episode only.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
211
lerobot/scripts/robot_controls/koch_configure.py
Normal file
211
lerobot/scripts/robot_controls/koch_configure.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
LCR Auto Configure: This program is used to automatically configure the Low Cost Robot (LCR) for the user.
|
||||
|
||||
The program will:
|
||||
1. Disable all torque motors of provided LCR.
|
||||
2. Ask the user to move the LCR to the position 1 (see CONFIGURING.md for more details).
|
||||
3. Record the position of the LCR.
|
||||
4. Ask the user to move the LCR to the position 2 (see CONFIGURING.md for more details).
|
||||
5. Record the position of the LCR.
|
||||
6. Ask the user to move back the LCR to the position 1.
|
||||
7. Record the position of the LCR.
|
||||
8. Calculate the offset of the LCR and save it to the configuration file.
|
||||
|
||||
It will also enable all appropriate operating modes for the LCR according if the LCR is a puppet or a master.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelBus, OperatingMode, DriveMode, TorqueMode
|
||||
|
||||
|
||||
def pause():
|
||||
"""
|
||||
Pause the program until the user presses the enter key.
|
||||
"""
|
||||
input("Press Enter to continue...")
|
||||
|
||||
|
||||
def prepare_configuration(arm: DynamixelBus):
|
||||
"""
|
||||
Prepare the configuration for the LCR.
|
||||
:param arm: DynamixelBus
|
||||
"""
|
||||
|
||||
# To be configured, all servos must be in "torque disable" mode
|
||||
arm.sync_write_torque_enable(TorqueMode.DISABLED.value)
|
||||
|
||||
# We need to work with 'extended position mode' (4) for all servos, because in joint mode (1) the servos can't
|
||||
# rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm,
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
arm.sync_write_operating_mode(OperatingMode.EXTENDED_POSITION.value, [1, 2, 3, 4, 5])
|
||||
|
||||
# Gripper is always 'position control current based' (5)
|
||||
arm.write_operating_mode(OperatingMode.CURRENT_CONTROLLED_POSITION.value, 6)
|
||||
|
||||
# We need to reset the homing offset for all servos
|
||||
arm.sync_write_homing_offset(0)
|
||||
|
||||
# We need to work with 'normal drive mode' (0) for all servos
|
||||
arm.sync_write_drive_mode(DriveMode.NON_INVERTED.value)
|
||||
|
||||
|
||||
def invert_appropriate_positions(positions: np.array, inverted: list[bool]) -> np.array:
|
||||
"""
|
||||
Invert the appropriate positions.
|
||||
:param positions: numpy array of positions
|
||||
:param inverted: list of booleans to determine if the position should be inverted
|
||||
:return: numpy array of inverted positions
|
||||
"""
|
||||
for i, invert in enumerate(inverted):
|
||||
if not invert and positions[i] is not None:
|
||||
positions[i] = -positions[i]
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def calculate_corrections(positions: np.array, inverted: list[bool]) -> np.array:
|
||||
"""
|
||||
Calculate the corrections for the positions.
|
||||
:param positions: numpy array of positions
|
||||
:param inverted: list of booleans to determine if the position should be inverted
|
||||
:return: numpy array of corrections
|
||||
"""
|
||||
|
||||
wanted = wanted_position_1()
|
||||
|
||||
correction = invert_appropriate_positions(positions, inverted)
|
||||
|
||||
for i in range(len(positions)):
|
||||
if correction[i] is not None:
|
||||
if inverted[i]:
|
||||
correction[i] -= wanted[i]
|
||||
else:
|
||||
correction[i] += wanted[i]
|
||||
|
||||
return correction
|
||||
|
||||
|
||||
def calculate_nearest_rounded_positions(positions: np.array) -> np.array:
|
||||
"""
|
||||
Calculate the nearest rounded positions.
|
||||
:param positions: numpy array of positions
|
||||
:return: numpy array of nearest rounded positions
|
||||
"""
|
||||
|
||||
return np.array(
|
||||
[round(positions[i] / 1024) * 1024 if positions[i] is not None else None for i in range(len(positions))])
|
||||
|
||||
|
||||
def configure_homing(arm: DynamixelBus, inverted: list[bool]) -> np.array:
|
||||
"""
|
||||
Configure the homing for the LCR.
|
||||
:param arm: DynamixelBus
|
||||
:param inverted: list of booleans to determine if the position should be inverted
|
||||
"""
|
||||
|
||||
# Reset homing offset for the servos
|
||||
arm.sync_write_homing_offset(0)
|
||||
|
||||
# Get the present positions of the servos
|
||||
present_positions = arm.sync_read_present_position_i32()
|
||||
|
||||
nearest_positions = calculate_nearest_rounded_positions(present_positions)
|
||||
|
||||
correction = calculate_corrections(nearest_positions, inverted)
|
||||
|
||||
# Write the homing offset for the servos
|
||||
arm.sync_write_homing_offset(correction)
|
||||
|
||||
|
||||
def configure_drive_mode(arm: DynamixelBus):
|
||||
"""
|
||||
Configure the drive mode for the LCR.
|
||||
:param arm: DynamixelBus
|
||||
:param homing: numpy array of homing
|
||||
"""
|
||||
|
||||
# Get current positions
|
||||
present_positions = arm.sync_read_present_position_i32()
|
||||
|
||||
nearest_positions = calculate_nearest_rounded_positions(present_positions)
|
||||
|
||||
# construct 'inverted' list comparing nearest_positions and wanted_position_2
|
||||
inverted = []
|
||||
|
||||
for i in range(len(nearest_positions)):
|
||||
inverted.append(nearest_positions[i] != wanted_position_2()[i])
|
||||
|
||||
# Write the drive mode for the servos
|
||||
arm.sync_write_drive_mode(
|
||||
[DriveMode.INVERTED.value if i else DriveMode.NON_INVERTED.value for i in inverted])
|
||||
|
||||
return inverted
|
||||
|
||||
|
||||
def wanted_position_1() -> np.array:
|
||||
"""
|
||||
The present position wanted in position 1 for the arm
|
||||
"""
|
||||
return np.array([0, -1024, 1024, 0, 0, 0])
|
||||
|
||||
|
||||
def wanted_position_2() -> np.array:
|
||||
"""
|
||||
The present position wanted in position 2 for the arm
|
||||
"""
|
||||
return np.array([1024, 0, 0, 1024, 1024, -1024])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LCR Auto Configure: This program is used to automatically configure the Low Cost Robot (LCR) for "
|
||||
"the user.")
|
||||
|
||||
parser.add_argument("--port", type=str, required=True, help="The port of the LCR.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
arm = DynamixelBus(
|
||||
args.port, {
|
||||
1: "x_series",
|
||||
2: "x_series",
|
||||
3: "x_series",
|
||||
4: "x_series",
|
||||
5: "x_series",
|
||||
6: "x_series",
|
||||
}
|
||||
)
|
||||
|
||||
prepare_configuration(arm)
|
||||
|
||||
# Ask the user to move the LCR to the position 1
|
||||
print("Please move the LCR to the position 1")
|
||||
pause()
|
||||
|
||||
configure_homing(arm, [False, False, False, False, False, False])
|
||||
|
||||
# Ask the user to move the LCR to the position 2
|
||||
print("Please move the LCR to the position 2")
|
||||
pause()
|
||||
|
||||
inverted = configure_drive_mode(arm)
|
||||
|
||||
# Ask the user to move back the LCR to the position 1
|
||||
print("Please move back the LCR to the position 1")
|
||||
pause()
|
||||
|
||||
configure_homing(arm, inverted)
|
||||
|
||||
print("Configuration done!")
|
||||
print("Make sure everything is working properly:")
|
||||
|
||||
while True:
|
||||
positions = arm.sync_read_present_position_i32()
|
||||
print(positions)
|
||||
|
||||
time.sleep(1)
|
||||
20
lerobot/scripts/robot_controls/record_dataset.py
Normal file
20
lerobot/scripts/robot_controls/record_dataset.py
Normal file
@@ -0,0 +1,20 @@
|
||||
|
||||
|
||||
import time
|
||||
from lerobot.common.robot_devices.robots.aloha import AlohaRobot
|
||||
import torch
|
||||
|
||||
|
||||
def record_dataset():
|
||||
robot = AlohaRobot(use_cameras=True)
|
||||
robot.init_teleop()
|
||||
|
||||
while True:
|
||||
now = time.time()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
dt_s = (time.time() - now)
|
||||
print(f"Latency (ms): {dt_s * 1000:.2f}\tFrequency: {1 / dt_s:.2f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
record_dataset()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user